[PATCH] memcg: Protect whole mem_cgroup_scan_tasks() using RCU.

From: Tetsuo Handa
Date: Fri Jun 14 2019 - 21:06:00 EST


syzbot is reporting a GPF (general protection fault) at for_each_thread()
from a memcg OOM event [1]. While select_bad_process() in a global OOM
event traverses all threads under RCU, select_bad_process() in a memcg OOM
event traverses threads without RCU, and I guess that this can result in
following a bogus pointer.

Suppose a process containing three threads T1, T2, T3 is in a memcg.
T3 invokes the memcg OOM killer and starts traversing from T1. T3 elevates
the refcount on T1, but T3 is preempted before the oom_unkillable_task(T1)
check. Then, T1 reaches do_exit() and does list_del_rcu(&T1->thread_node).

  do_exit() {
    cgroup_exit() {
      css_set_move_task(tsk, cset, NULL, false);
    }
    exit_notify() {
      release_task() {
        __exit_signal() {
          __unhash_process() {
            list_del_rcu(&p->thread_node);
          }
        }
      }
    }
  }

Then, T2 also reaches do_exit() and does list_del_rcu(&T2->thread_node).
Since the refcount of T1 was kept elevated by T3, T1 cannot be freed. But
since the refcount of T2 was not elevated by T3, T2 can complete do_exit()
and is freed as soon as an RCU grace period elapses. At this point, since
T1 was removed from the thread group before T2 was removed, T1's next
pointer was never updated and still refers to the already freed T2. If the
memory used for T2 is reallocated before T3 resumes execution, accessing
T1's next thread will not be reported as a use-after-free, but the memory
referenced as T1's next thread will contain bogus values.
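
The stale next pointer can be reproduced in a small userspace model. This
is only a sketch under stated assumptions: struct thread, thread_add_tail(),
thread_del_rcu_like() and replay_scenario() are hypothetical names, not
kernel APIs, and the "freed" flag merely stands in for the memory being
returned to the allocator. It mimics only the one property of
list_del_rcu() that matters here, namely that the removed entry keeps its
own next pointer.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for task_struct: only the thread-list linkage. */
struct thread {
	struct thread *next;
	struct thread *prev;
	int freed;	/* models "memory was handed back to the allocator" */
};

/* Append t at the tail of the circular list anchored at head. */
static void thread_add_tail(struct thread *t, struct thread *head)
{
	t->prev = head->prev;
	t->next = head;
	head->prev->next = t;
	head->prev = t;
}

/*
 * Mimics list_del_rcu(): the neighbours are relinked, but the removed
 * entry keeps its own next pointer so that readers can keep walking.
 */
static void thread_del_rcu_like(struct thread *t)
{
	t->next->prev = t->prev;
	t->prev->next = t->next;
	t->prev = NULL;	/* the kernel poisons prev; NULL is used here */
}

/* Replays the T1/T2/T3 scenario; returns 1 if T1's next entry is stale. */
int replay_scenario(void)
{
	struct thread head = { NULL, NULL, 0 };
	struct thread t1 = { NULL, NULL, 0 };
	struct thread t2 = { NULL, NULL, 0 };
	struct thread t3 = { NULL, NULL, 0 };

	head.next = &head;
	head.prev = &head;
	thread_add_tail(&t1, &head);
	thread_add_tail(&t2, &head);
	thread_add_tail(&t3, &head);

	/* T3 holds a reference on T1 only, then gets preempted. */
	thread_del_rcu_like(&t1);	/* T1 exits; t1.next still points at t2 */
	thread_del_rcu_like(&t2);	/* T2 exits; nothing updates t1.next */
	t2.freed = 1;			/* grace period elapsed, T2 is freed */

	assert(head.next == &t3);	/* the live list no longer contains T2 */
	assert(t1.next == &t2);		/* but T1 still points at it */
	return t1.next->freed;
}
```

In this model the list itself stays consistent (head links only to T3),
yet a walker that resumes from T1 steps onto T2. Holding the RCU read lock
across the whole walk prevents that in the kernel because T2 cannot be
freed until the reader leaves its critical section.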

Thus, I think that the rule is: when traversing threads inside a section
between css_task_iter_start() and css_task_iter_end(), the processing of
each thread must not use e.g. for_each_thread() unless the whole section
is protected by RCU.

[1] https://syzkaller.appspot.com/bug?id=4559bc383e7c73a35bc6c8336557635459fb7a62

Signed-off-by: Tetsuo Handa <penguin-kernel@xxxxxxxxxxxxxxxxxxx>
Reported-by: syzbot <syzbot+d0fc9d3c166bc5e4a94b@xxxxxxxxxxxxxxxxxxxxxxxxx>
---
mm/memcontrol.c | 2 ++
1 file changed, 2 insertions(+)

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index ba9138a..8e01f01 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1159,6 +1159,7 @@ int mem_cgroup_scan_tasks(struct mem_cgroup *memcg,

 	BUG_ON(memcg == root_mem_cgroup);
 
+	rcu_read_lock();
 	for_each_mem_cgroup_tree(iter, memcg) {
 		struct css_task_iter it;
 		struct task_struct *task;
@@ -1172,6 +1173,7 @@ int mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
 			break;
 		}
 	}
+	rcu_read_unlock();
 	return ret;
 }

--
1.8.3.1