1414#include <linux/mutex.h>
1515#include <linux/page_counter.h>
1616#include <linux/parser.h>
17+ #include <linux/refcount.h>
1718#include <linux/rculist.h>
1819#include <linux/slab.h>
1920
@@ -71,7 +72,9 @@ struct dmem_cgroup_pool_state {
7172 struct rcu_head rcu ;
7273
7374 struct page_counter cnt ;
75+ struct dmem_cgroup_pool_state * parent ;
7476
77+ refcount_t ref ;
7578 bool inited ;
7679};
7780
@@ -88,6 +91,9 @@ struct dmem_cgroup_pool_state {
8891static DEFINE_SPINLOCK (dmemcg_lock );
8992static LIST_HEAD (dmem_cgroup_regions );
9093
94+ static void dmemcg_free_region (struct kref * ref );
95+ static void dmemcg_pool_free_rcu (struct rcu_head * rcu );
96+
9197static inline struct dmemcg_state *
9298css_to_dmemcs (struct cgroup_subsys_state * css )
9399{
@@ -104,10 +110,38 @@ static struct dmemcg_state *parent_dmemcs(struct dmemcg_state *cg)
104110 return cg -> css .parent ? css_to_dmemcs (cg -> css .parent ) : NULL ;
105111}
106112
113+ static void dmemcg_pool_get (struct dmem_cgroup_pool_state * pool )
114+ {
115+ refcount_inc (& pool -> ref );
116+ }
117+
118+ static bool dmemcg_pool_tryget (struct dmem_cgroup_pool_state * pool )
119+ {
120+ return refcount_inc_not_zero (& pool -> ref );
121+ }
122+
123+ static void dmemcg_pool_put (struct dmem_cgroup_pool_state * pool )
124+ {
125+ if (!refcount_dec_and_test (& pool -> ref ))
126+ return ;
127+
128+ call_rcu (& pool -> rcu , dmemcg_pool_free_rcu );
129+ }
130+
131+ static void dmemcg_pool_free_rcu (struct rcu_head * rcu )
132+ {
133+ struct dmem_cgroup_pool_state * pool = container_of (rcu , typeof (* pool ), rcu );
134+
135+ if (pool -> parent )
136+ dmemcg_pool_put (pool -> parent );
137+ kref_put (& pool -> region -> ref , dmemcg_free_region );
138+ kfree (pool );
139+ }
140+
107141static void free_cg_pool (struct dmem_cgroup_pool_state * pool )
108142{
109143 list_del (& pool -> region_node );
110- kfree (pool );
144+ dmemcg_pool_put (pool );
111145}
112146
113147static void
@@ -342,6 +376,12 @@ alloc_pool_single(struct dmemcg_state *dmemcs, struct dmem_cgroup_region *region
342376 page_counter_init (& pool -> cnt ,
343377 ppool ? & ppool -> cnt : NULL , true);
344378 reset_all_resource_limits (pool );
379+ refcount_set (& pool -> ref , 1 );
380+ kref_get (& region -> ref );
381+ if (ppool && !pool -> parent ) {
382+ pool -> parent = ppool ;
383+ dmemcg_pool_get (ppool );
384+ }
345385
346386 list_add_tail_rcu (& pool -> css_node , & dmemcs -> pools );
347387 list_add_tail (& pool -> region_node , & region -> pools );
@@ -389,6 +429,10 @@ get_cg_pool_locked(struct dmemcg_state *dmemcs, struct dmem_cgroup_region *regio
389429
390430 /* Fix up parent links, mark as inited. */
391431 pool -> cnt .parent = & ppool -> cnt ;
432+ if (ppool && !pool -> parent ) {
433+ pool -> parent = ppool ;
434+ dmemcg_pool_get (ppool );
435+ }
392436 pool -> inited = true;
393437
394438 pool = ppool ;
@@ -423,7 +467,7 @@ static void dmemcg_free_region(struct kref *ref)
423467 */
424468void dmem_cgroup_unregister_region (struct dmem_cgroup_region * region )
425469{
426- struct list_head * entry ;
470+ struct dmem_cgroup_pool_state * pool , * next ;
427471
428472 if (!region )
429473 return ;
@@ -433,11 +477,10 @@ void dmem_cgroup_unregister_region(struct dmem_cgroup_region *region)
433477 /* Remove from global region list */
434478 list_del_rcu (& region -> region_node );
435479
436- list_for_each_rcu (entry , & region -> pools ) {
437- struct dmem_cgroup_pool_state * pool =
438- container_of (entry , typeof (* pool ), region_node );
439-
480+ list_for_each_entry_safe (pool , next , & region -> pools , region_node ) {
440481 list_del_rcu (& pool -> css_node );
482+ list_del (& pool -> region_node );
483+ dmemcg_pool_put (pool );
441484 }
442485
443486 /*
@@ -518,8 +561,10 @@ static struct dmem_cgroup_region *dmemcg_get_region_by_name(const char *name)
518561 */
519562void dmem_cgroup_pool_state_put (struct dmem_cgroup_pool_state * pool )
520563{
521- if (pool )
564+ if (pool ) {
522565 css_put (& pool -> cs -> css );
566+ dmemcg_pool_put (pool );
567+ }
523568}
524569EXPORT_SYMBOL_GPL (dmem_cgroup_pool_state_put );
525570
@@ -533,6 +578,8 @@ get_cg_pool_unlocked(struct dmemcg_state *cg, struct dmem_cgroup_region *region)
533578 pool = find_cg_pool_locked (cg , region );
534579 if (pool && !READ_ONCE (pool -> inited ))
535580 pool = NULL ;
581+ if (pool && !dmemcg_pool_tryget (pool ))
582+ pool = NULL ;
536583 rcu_read_unlock ();
537584
538585 while (!pool ) {
@@ -541,6 +588,8 @@ get_cg_pool_unlocked(struct dmemcg_state *cg, struct dmem_cgroup_region *region)
541588 pool = get_cg_pool_locked (cg , region , & allocpool );
542589 else
543590 pool = ERR_PTR (- ENODEV );
591+ if (!IS_ERR (pool ))
592+ dmemcg_pool_get (pool );
544593 spin_unlock (& dmemcg_lock );
545594
546595 if (pool == ERR_PTR (- ENOMEM )) {
@@ -576,6 +625,7 @@ void dmem_cgroup_uncharge(struct dmem_cgroup_pool_state *pool, u64 size)
576625
577626 page_counter_uncharge (& pool -> cnt , size );
578627 css_put (& pool -> cs -> css );
628+ dmemcg_pool_put (pool );
579629}
580630EXPORT_SYMBOL_GPL (dmem_cgroup_uncharge );
581631
@@ -627,7 +677,9 @@ int dmem_cgroup_try_charge(struct dmem_cgroup_region *region, u64 size,
627677 if (ret_limit_pool ) {
628678 * ret_limit_pool = container_of (fail , struct dmem_cgroup_pool_state , cnt );
629679 css_get (& (* ret_limit_pool )-> cs -> css );
680+ dmemcg_pool_get (* ret_limit_pool );
630681 }
682+ dmemcg_pool_put (pool );
631683 ret = - EAGAIN ;
632684 goto err ;
633685 }
@@ -700,6 +752,9 @@ static ssize_t dmemcg_limit_write(struct kernfs_open_file *of,
700752 if (!region_name [0 ])
701753 continue ;
702754
755+ if (!options || !* options )
756+ return - EINVAL ;
757+
703758 rcu_read_lock ();
704759 region = dmemcg_get_region_by_name (region_name );
705760 rcu_read_unlock ();
@@ -719,6 +774,7 @@ static ssize_t dmemcg_limit_write(struct kernfs_open_file *of,
719774
720775 /* And commit */
721776 apply (pool , new_limit );
777+ dmemcg_pool_put (pool );
722778
723779out_put :
724780 kref_put (& region -> ref , dmemcg_free_region );
0 commit comments