[PATCH 1/1] userfaultfd: use RCU to free the task struct when fork fails if MEMCG

From: Andrea Arcangeli
Date: Tue Mar 05 2019 - 19:21:37 EST


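MEMCG depends on the task structure not being freed under
rcu_read_lock() in get_mem_cgroup_from_mm() after it dereferences
mm->owner. When fork fails after userfaultfd has already reported the
forked vmas to the monitor, the child's mm can still be reached and
mm->owner may still point to the child task_struct that fork is about
to free. So clear mm->owner on the fork failure paths and, when MEMCG
is enabled, free the task_struct only after an RCU grace period.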

A better fix would be to avoid registering forked vmas in userfaultfd
contexts reported to the monitor in case fork ends up failing.

Signed-off-by: Andrea Arcangeli <aarcange@xxxxxxxxxx>
---
kernel/fork.c | 34 ++++++++++++++++++++++++++++++++--
1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/kernel/fork.c b/kernel/fork.c
index eb9953c82104..3bcbb361ffbc 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -953,6 +953,15 @@ static void mm_init_aio(struct mm_struct *mm)
#endif
}

+static __always_inline void mm_clear_owner(struct mm_struct *mm,
+					   struct task_struct *tsk)
+{
+#ifdef CONFIG_MEMCG	/* drop @tsk as owner of @mm, if it still is */
+	if (mm->owner == tsk)
+		mm->owner = NULL;
+#endif /* CONFIG_MEMCG */
+}
+
static void mm_init_owner(struct mm_struct *mm, struct task_struct *p)
{
#ifdef CONFIG_MEMCG
@@ -1345,6 +1354,7 @@ static struct mm_struct *dup_mm(struct task_struct *tsk)
free_pt:
/* don't put binfmt in mmput, we haven't got module yet */
mm->binfmt = NULL;
+ mm_init_owner(mm, NULL);
mmput(mm);

fail_nomem:
@@ -1676,6 +1686,24 @@ static inline void rcu_copy_process(struct task_struct *p)
#endif /* #ifdef CONFIG_TASKS_RCU */
}

+/* RCU callback queued by delayed_free_task() under CONFIG_MEMCG. */
+static void __delayed_free_task(struct rcu_head *rhp)
+{
+	free_task(container_of(rhp, struct task_struct, rcu));
+}
+
+/*
+ * Free a task_struct on the copy_process() failure path.  With MEMCG, an
+ * rcu_read_lock() user of mm->owner may still see it, so wait a grace period.
+ */
+static __always_inline void delayed_free_task(struct task_struct *tsk)
+{
+	if (IS_ENABLED(CONFIG_MEMCG))
+		call_rcu(&tsk->rcu, __delayed_free_task);
+	else
+		free_task(tsk);
+}
+
/*
* This creates a new process as a copy of the old one,
* but does not actually start it yet.
@@ -2137,8 +2165,10 @@ static __latent_entropy struct task_struct *copy_process(
bad_fork_cleanup_namespaces:
exit_task_namespaces(p);
bad_fork_cleanup_mm:
- if (p->mm)
+ if (p->mm) {
+ mm_clear_owner(p->mm, p);
mmput(p->mm);
+ }
bad_fork_cleanup_signal:
if (!(clone_flags & CLONE_THREAD))
free_signal_struct(p->signal);
@@ -2169,7 +2199,7 @@ static __latent_entropy struct task_struct *copy_process(
bad_fork_free:
p->state = TASK_DEAD;
put_task_stack(p);
- free_task(p);
+ delayed_free_task(p);
fork_out:
spin_lock_irq(&current->sighand->siglock);
hlist_del_init(&delayed.node);