[PATCH 4/4] [RFC!] mm: 'struct mm_struct' reference counting debugging

From: Vegard Nossum
Date: Fri Dec 16 2016 - 04:59:53 EST


Reference counting bugs are by their nature hard to debug, since the actual
manifestation of one can occur very far from where the error is introduced
(e.g. a missing get() only manifests as a use-after-free when the reference
count prematurely drops to 0, which could be arbitrarily long after the point
where the get() should have happened if there are other users). I wrote this
patch to try to track down a suspected 'mm_struct' reference counting bug.

The basic idea is to keep track of all references, not just with a reference
counter, but with an actual reference _list_. Whenever you get() or put() a
reference, you also add or remove yourself, respectively, from the reference
list. This really helps debugging because (for example) you always put a
specific reference, meaning that if that reference was not yours to put, you
will notice immediately (rather than when the reference count drops to 0
while you still hold an active reference).
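To illustrate the idea outside the kernel, here is a minimal single-threaded userspace model. The names (struct obj, obj_get(), obj_put()) are hypothetical and not part of this patch; the real implementation in mm/mm_ref.c also records a pid and a stack trace per reference and serializes list access with ->mm_refs_lock. Each holder supplies its own list node, so putting a reference you do not hold is rejected at the put() itself.

```c
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical userspace model; names are illustrative only. */
struct ref {
	struct ref *prev, *next;  /* links on the owner's reference list */
	bool active;              /* true while this holder owns a reference */
};

struct obj {
	int count;        /* the plain reference counter */
	struct ref head;  /* list of every reference currently held */
};

static void ref_init(struct ref *r)
{
	r->prev = r->next = r;
	r->active = false;
}

static void obj_init(struct obj *o)
{
	o->count = 0;
	ref_init(&o->head);
}

/* get(): bump the counter and record which holder took the reference. */
static bool obj_get(struct obj *o, struct ref *r)
{
	if (r->active) {
		fprintf(stderr, "BUG: reference %p taken twice\n", (void *)r);
		return false;
	}
	o->count++;
	r->prev = o->head.prev;   /* append at the tail of the list */
	r->next = &o->head;
	o->head.prev->next = r;
	o->head.prev = r;
	r->active = true;
	return true;
}

/* put(): drop one specific reference; putting one you don't hold fails here. */
static bool obj_put(struct obj *o, struct ref *r)
{
	if (!r->active) {
		fprintf(stderr, "BUG: put of reference %p that is not held\n", (void *)r);
		return false;
	}
	r->prev->next = r->next;  /* unlink from the list */
	r->next->prev = r->prev;
	ref_init(r);
	o->count--;
	return true;
}
```

With only a counter, putting reference a twice would silently bring the count to 0 while b's holder still uses the object; with the list, the second obj_put(&o, &a) fails immediately and the counter is left untouched.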

The main interface is in <linux/mm_ref_types.h> and <linux/mm_ref.h>, while
the implementation lives in mm/mm_ref.c. Since 'struct mm_struct' has both
->mm_users and ->mm_count, we introduce helpers for both of them, but use
the same data structure for each (struct mm_ref). The low-level rules (i.e.
the ones we have to follow, but which nobody else should really have to
care about since they use the higher-level interface) are:

- after incrementing ->mm_count you also have to call get_mm_ref()

- before decrementing ->mm_count you also have to call put_mm_ref()

- after incrementing ->mm_users you also have to call get_mm_users_ref()

- before decrementing ->mm_users you also have to call put_mm_users_ref()
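The ordering in these rules matters. My reading of the rationale: a reference is recorded only once the counter has actually been raised, and it is unrecorded before the counter is lowered, because after the final decrement the mm_struct (and the lists embedded in it) may already have been freed. The sketch below is a hypothetical single-threaded model of that ordering for ->mm_count (plain ints instead of atomic_t, no locking, and the list reduced to a count of recorded entries); none of the model_* names exist in the patch.

```c
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical single-threaded model of the low-level rules. */
struct ref { bool active; };

struct mm_model {
	int mm_count;   /* stands in for ->mm_count */
	int nr_listed;  /* stands in for the length of ->mm_count_list */
	bool freed;
};

static void model_get_mm_ref(struct mm_model *mm, struct ref *r)
{
	r->active = true;
	mm->nr_listed++;
}

static bool model_put_mm_ref(struct mm_model *mm, struct ref *r)
{
	if (!r->active) {
		fprintf(stderr, "BUG: dropping a reference that is not held\n");
		return false;
	}
	r->active = false;
	mm->nr_listed--;
	return true;
}

static void model_free(struct mm_model *mm)
{
	/* Thanks to the ordering below, no recorded reference can remain here. */
	if (mm->nr_listed != 0)
		fprintf(stderr, "BUG: freeing with %d references listed\n", mm->nr_listed);
	mm->freed = true;
}

/* mmgrab(): increment ->mm_count first, then record the reference. */
static void model_mmgrab(struct mm_model *mm, struct ref *r)
{
	mm->mm_count++;
	model_get_mm_ref(mm, r);
}

/*
 * mmdrop(): unrecord the reference first, then decrement ->mm_count.
 * The other order would touch the list after a possible free.
 */
static bool model_mmdrop(struct mm_model *mm, struct ref *r)
{
	if (!model_put_mm_ref(mm, r))
		return false;  /* not ours to drop: the count is left alone */
	if (--mm->mm_count == 0)
		model_free(mm);
	return true;
}
```

The ->mm_users helpers would follow the same pattern with get_mm_users_ref()/put_mm_users_ref().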

The rules that most of the rest of the kernel will care about are:

- functions that acquire and return an mm_struct should take a
'struct mm_ref *' which they can pass on to mmget()/mmgrab()/etc.

- functions that release an mm_struct passed as a parameter should also
take a 'struct mm_ref *' which they can pass on to mmput()/mmdrop()/etc.

- any function that temporarily acquires an mm_struct reference should
use MM_REF() to define an on-stack reference and pass it on to
mmget()/mmput()/mmgrab()/mmdrop()/etc.

- any structure that holds an mm_struct pointer must also include a
'struct mm_ref' member; when the mm_struct pointer is modified you
would typically also call mmget()/mmgrab()/mmput()/mmdrop() and they
should be called with this mm_ref

- you can convert (for example) an on-stack reference to an in-struct
reference using move_mm_ref(). This is semantically equivalent to
(atomically) taking the new reference and dropping the old one, but
doesn't actually need to modify the reference count
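A sketch of the higher-level pattern, again as a hypothetical userspace model rather than the real API: a function takes a temporary reference on the stack (MODEL_REF() plays the role of MM_REF()) and, once setup has succeeded, hands it to the reference embedded in a structure. The move changes which holder is recorded but leaves the counter alone, as described for move_mm_ref() above. struct holder, holder_attach() and the model_* helpers are illustrative names only.

```c
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical model; MODEL_REF() plays the role of MM_REF(). */
struct ref { bool active; };
#define MODEL_REF(name) struct ref name = { .active = false }

struct mm_model {
	int mm_users;   /* stands in for ->mm_users */
	int nr_listed;  /* stands in for the length of ->mm_users_list */
};

static void model_mmget(struct mm_model *mm, struct ref *r)
{
	mm->mm_users++;
	r->active = true;
	mm->nr_listed++;
}

static bool model_mmput(struct mm_model *mm, struct ref *r)
{
	if (!r->active) {
		fprintf(stderr, "BUG: mmput() with a reference that is not held\n");
		return false;
	}
	r->active = false;
	mm->nr_listed--;
	mm->mm_users--;
	return true;
}

/* Like move_mm_users_ref(): change the recorded holder, not the count. */
static bool model_move_ref(struct mm_model *mm, struct ref *from, struct ref *to)
{
	(void)mm;  /* count and list length are both unchanged */
	if (!from->active || to->active) {
		fprintf(stderr, "BUG: bad reference move\n");
		return false;
	}
	from->active = false;
	to->active = true;
	return true;
}

/* A structure that keeps an mm pointer also embeds its reference. */
struct holder {
	struct mm_model *mm;
	struct ref mm_ref;
};

static void holder_attach(struct holder *h, struct mm_model *mm)
{
	MODEL_REF(tmp);         /* temporary on-stack reference */
	model_mmget(mm, &tmp);
	/* ... setup that could fail and would then call model_mmput(mm, &tmp) ... */
	model_move_ref(mm, &tmp, &h->mm_ref);
	h->mm = mm;             /* tmp is inactive now, so it can go out of scope */
}
```

After holder_attach() the structure owns the reference, so only model_mmput(mm, &h->mm_ref) may release it; an mmput() through any other reference is reported instead of corrupting the count.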

I don't really have any delusions about getting this into mainline
(_especially_ not without a CONFIG_MM_REF toggle and zero impact in the =n
case), but I'm posting it in case somebody would find it useful and maybe
to start a discussion about whether this is something that can be usefully
generalized to other core data structures with complicated
reference/ownership models.

The patch really does make it very explicit who holds every reference
taken and where references are implicitly transferred, for example in
finish_task_switch() where the ownership of the reference to 'oldmm' is
implicitly transferred from 'prev->mm' to 'rq->prev_mm', or in
flush_old_exec() where ownership of 'bprm->mm' is implicitly
transferred from 'bprm' to 'current->mm'. These transfers are subtle
because there is no explicit get()/put() in the code.

There are some users that haven't been converted by this patch (and
many more that haven't been tested) -- x86-64 defconfig should work out of
the box, though. The conversion for the rest of the kernel should be
mostly straightforward (the main challenge was fork/exec).

Thanks-to: Rik van Riel <riel@xxxxxxxxxx>
Thanks-to: Matthew Wilcox <mawilcox@xxxxxxxxxxxxx>
Thanks-to: Peter Zijlstra <peterz@xxxxxxxxxxxxx>
Cc: Andrew Morton <akpm@xxxxxxxxxxxxxxxxxxxx>
Cc: Michal Hocko <mhocko@xxxxxxxx>
Cc: Al Viro <viro@xxxxxxxxxxxxxxxxxx>
Cc: Ingo Molnar <mingo@xxxxxxxxxx>
Cc: Linus Torvalds <torvalds@xxxxxxxxxxxxxxxxxxxx>
Signed-off-by: Vegard Nossum <vegard.nossum@xxxxxxxxxx>
---
arch/x86/kernel/cpu/common.c | 2 +-
drivers/gpu/drm/i915/i915_gem_userptr.c | 25 +++--
drivers/vhost/vhost.c | 7 +-
drivers/vhost/vhost.h | 1 +
fs/exec.c | 20 +++-
fs/proc/array.c | 15 +--
fs/proc/base.c | 97 ++++++++++++-------
fs/proc/internal.h | 4 +-
fs/proc/task_mmu.c | 50 +++++++---
include/linux/binfmts.h | 1 +
include/linux/init_task.h | 1 +
include/linux/kvm_host.h | 3 +
include/linux/mm_ref.h | 48 ++++++++++
include/linux/mm_ref_types.h | 41 ++++++++
include/linux/mm_types.h | 8 ++
include/linux/mmu_notifier.h | 8 +-
include/linux/sched.h | 42 +++++---
kernel/cpuset.c | 22 +++--
kernel/events/core.c | 5 +-
kernel/exit.c | 5 +-
kernel/fork.c | 51 ++++++----
kernel/futex.c | 124 +++++++++++++-----------
kernel/sched/core.c | 14 ++-
kernel/sched/sched.h | 1 +
kernel/sys.c | 5 +-
kernel/trace/trace_output.c | 5 +-
kernel/tsacct.c | 5 +-
mm/Makefile | 2 +-
mm/init-mm.c | 6 ++
mm/memory.c | 5 +-
mm/mempolicy.c | 5 +-
mm/migrate.c | 5 +-
mm/mm_ref.c | 163 ++++++++++++++++++++++++++++++++
mm/mmu_context.c | 9 +-
mm/mmu_notifier.c | 20 ++--
mm/oom_kill.c | 12 ++-
mm/process_vm_access.c | 5 +-
mm/swapfile.c | 29 +++---
mm/util.c | 5 +-
virt/kvm/async_pf.c | 9 +-
virt/kvm/kvm_main.c | 16 +++-
41 files changed, 668 insertions(+), 233 deletions(-)
create mode 100644 include/linux/mm_ref.h
create mode 100644 include/linux/mm_ref_types.h
create mode 100644 mm/mm_ref.c

diff --git a/arch/x86/kernel/cpu/common.c b/arch/x86/kernel/cpu/common.c
index b580da4582e1..edf16f695130 100644
--- a/arch/x86/kernel/cpu/common.c
+++ b/arch/x86/kernel/cpu/common.c
@@ -1555,7 +1555,7 @@ void cpu_init(void)
for (i = 0; i <= IO_BITMAP_LONGS; i++)
t->io_bitmap[i] = ~0UL;

- mmgrab(&init_mm);
+ mmgrab(&init_mm, &me->mm_ref);
me->active_mm = &init_mm;
BUG_ON(me->mm);
enter_lazy_tlb(&init_mm, me);
diff --git a/drivers/gpu/drm/i915/i915_gem_userptr.c b/drivers/gpu/drm/i915/i915_gem_userptr.c
index e97f9ade99fc..498d311e1a80 100644
--- a/drivers/gpu/drm/i915/i915_gem_userptr.c
+++ b/drivers/gpu/drm/i915/i915_gem_userptr.c
@@ -34,8 +34,10 @@

struct i915_mm_struct {
struct mm_struct *mm;
+ struct mm_ref mm_ref;
struct drm_i915_private *i915;
struct i915_mmu_notifier *mn;
+ struct mm_ref mn_ref;
struct hlist_node node;
struct kref kref;
struct work_struct work;
@@ -159,7 +161,7 @@ static const struct mmu_notifier_ops i915_gem_userptr_notifier = {
};

static struct i915_mmu_notifier *
-i915_mmu_notifier_create(struct mm_struct *mm)
+i915_mmu_notifier_create(struct mm_struct *mm, struct mm_ref *mm_ref)
{
struct i915_mmu_notifier *mn;
int ret;
@@ -178,7 +180,7 @@ i915_mmu_notifier_create(struct mm_struct *mm)
}

/* Protected by mmap_sem (write-lock) */
- ret = __mmu_notifier_register(&mn->mn, mm);
+ ret = __mmu_notifier_register(&mn->mn, mm, mm_ref);
if (ret) {
destroy_workqueue(mn->wq);
kfree(mn);
@@ -217,7 +219,7 @@ i915_mmu_notifier_find(struct i915_mm_struct *mm)
down_write(&mm->mm->mmap_sem);
mutex_lock(&mm->i915->mm_lock);
if ((mn = mm->mn) == NULL) {
- mn = i915_mmu_notifier_create(mm->mm);
+ mn = i915_mmu_notifier_create(mm->mm, &mm->mn_ref);
if (!IS_ERR(mn))
mm->mn = mn;
}
@@ -260,12 +262,12 @@ i915_gem_userptr_init__mmu_notifier(struct drm_i915_gem_object *obj,

static void
i915_mmu_notifier_free(struct i915_mmu_notifier *mn,
- struct mm_struct *mm)
+ struct mm_struct *mm, struct mm_ref *mm_ref)
{
if (mn == NULL)
return;

- mmu_notifier_unregister(&mn->mn, mm);
+ mmu_notifier_unregister(&mn->mn, mm, mm_ref);
destroy_workqueue(mn->wq);
kfree(mn);
}
@@ -341,9 +343,11 @@ i915_gem_userptr_init__mm_struct(struct drm_i915_gem_object *obj)
mm->i915 = to_i915(obj->base.dev);

mm->mm = current->mm;
- mmgrab(current->mm);
+ INIT_MM_REF(&mm->mm_ref);
+ mmgrab(current->mm, &mm->mm_ref);

mm->mn = NULL;
+ INIT_MM_REF(&mm->mn_ref);

/* Protected by dev_priv->mm_lock */
hash_add(dev_priv->mm_structs,
@@ -361,8 +365,8 @@ static void
__i915_mm_struct_free__worker(struct work_struct *work)
{
struct i915_mm_struct *mm = container_of(work, typeof(*mm), work);
- i915_mmu_notifier_free(mm->mn, mm->mm);
- mmdrop(mm->mm);
+ i915_mmu_notifier_free(mm->mn, mm->mm, &mm->mn_ref);
+ mmdrop(mm->mm, &mm->mm_ref);
kfree(mm);
}

@@ -508,13 +512,14 @@ __i915_gem_userptr_get_pages_worker(struct work_struct *_work)
pvec = drm_malloc_gfp(npages, sizeof(struct page *), GFP_TEMPORARY);
if (pvec != NULL) {
struct mm_struct *mm = obj->userptr.mm->mm;
+ MM_REF(mm_ref);
unsigned int flags = 0;

if (!obj->userptr.read_only)
flags |= FOLL_WRITE;

ret = -EFAULT;
- if (mmget_not_zero(mm)) {
+ if (mmget_not_zero(mm, &mm_ref)) {
down_read(&mm->mmap_sem);
while (pinned < npages) {
ret = get_user_pages_remote
@@ -529,7 +534,7 @@ __i915_gem_userptr_get_pages_worker(struct work_struct *_work)
pinned += ret;
}
up_read(&mm->mmap_sem);
- mmput(mm);
+ mmput(mm, &mm_ref);
}
}

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c6f2d89c0e97..4470abf94fe8 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -407,6 +407,7 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL;
+ INIT_MM_REF(&dev->mm_ref);
dev->worker = NULL;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
@@ -483,7 +484,7 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
}

/* No owner, become one */
- dev->mm = get_task_mm(current);
+ dev->mm = get_task_mm(current, &dev->mm_ref);
worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
if (IS_ERR(worker)) {
err = PTR_ERR(worker);
@@ -507,7 +508,7 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
dev->worker = NULL;
err_worker:
if (dev->mm)
- mmput(dev->mm);
+ mmput(dev->mm, &dev->mm_ref);
dev->mm = NULL;
err_mm:
return err;
@@ -639,7 +640,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
dev->worker = NULL;
}
if (dev->mm)
- mmput(dev->mm);
+ mmput(dev->mm, &dev->mm_ref);
dev->mm = NULL;
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 78f3c5fc02e4..64fdcfa9cf67 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -151,6 +151,7 @@ struct vhost_msg_node {

struct vhost_dev {
struct mm_struct *mm;
+ struct mm_ref mm_ref;
struct mutex mutex;
struct vhost_virtqueue **vqs;
int nvqs;
diff --git a/fs/exec.c b/fs/exec.c
index 4e497b9ee71e..13afedb2821d 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -380,7 +380,7 @@ static int bprm_mm_init(struct linux_binprm *bprm)
int err;
struct mm_struct *mm = NULL;

- bprm->mm = mm = mm_alloc();
+ bprm->mm = mm = mm_alloc(&bprm->mm_ref);
err = -ENOMEM;
if (!mm)
goto err;
@@ -394,7 +394,7 @@ static int bprm_mm_init(struct linux_binprm *bprm)
err:
if (mm) {
bprm->mm = NULL;
- mmdrop(mm);
+ mmdrop(mm, &bprm->mm_ref);
}

return err;
@@ -996,6 +996,8 @@ static int exec_mmap(struct mm_struct *mm)
{
struct task_struct *tsk;
struct mm_struct *old_mm, *active_mm;
+ MM_REF(old_mm_ref);
+ MM_REF(active_mm_ref);

/* Notify parent that we're no longer interested in the old VM */
tsk = current;
@@ -1015,9 +1017,14 @@ static int exec_mmap(struct mm_struct *mm)
up_read(&old_mm->mmap_sem);
return -EINTR;
}
+
+ move_mm_users_ref(old_mm, &current->mm_ref, &old_mm_ref);
}
task_lock(tsk);
+
active_mm = tsk->active_mm;
+ if (!old_mm)
+ move_mm_ref(active_mm, &tsk->mm_ref, &active_mm_ref);
tsk->mm = mm;
tsk->active_mm = mm;
activate_mm(active_mm, mm);
@@ -1029,10 +1036,10 @@ static int exec_mmap(struct mm_struct *mm)
BUG_ON(active_mm != old_mm);
setmax_mm_hiwater_rss(&tsk->signal->maxrss, old_mm);
mm_update_next_owner(old_mm);
- mmput(old_mm);
+ mmput(old_mm, &old_mm_ref);
return 0;
}
- mmdrop(active_mm);
+ mmdrop(active_mm, &active_mm_ref);
return 0;
}

@@ -1258,6 +1265,7 @@ int flush_old_exec(struct linux_binprm * bprm)
if (retval)
goto out;

+ move_mm_ref(bprm->mm, &bprm->mm_ref, &current->mm_ref);
bprm->mm = NULL; /* We're using it now */

set_fs(USER_DS);
@@ -1674,6 +1682,8 @@ static int do_execveat_common(int fd, struct filename *filename,
if (!bprm)
goto out_files;

+ INIT_MM_REF(&bprm->mm_ref);
+
retval = prepare_bprm_creds(bprm);
if (retval)
goto out_free;
@@ -1760,7 +1770,7 @@ static int do_execveat_common(int fd, struct filename *filename,
out:
if (bprm->mm) {
acct_arg_size(bprm, 0);
- mmput(bprm->mm);
+ mmput(bprm->mm, &bprm->mm_ref);
}

out_unmark:
diff --git a/fs/proc/array.c b/fs/proc/array.c
index 81818adb8e9e..3e02be82c2f4 100644
--- a/fs/proc/array.c
+++ b/fs/proc/array.c
@@ -367,14 +367,15 @@ static void task_cpus_allowed(struct seq_file *m, struct task_struct *task)
int proc_pid_status(struct seq_file *m, struct pid_namespace *ns,
struct pid *pid, struct task_struct *task)
{
- struct mm_struct *mm = get_task_mm(task);
+ MM_REF(mm_ref);
+ struct mm_struct *mm = get_task_mm(task, &mm_ref);

task_name(m, task);
task_state(m, ns, pid, task);

if (mm) {
task_mem(m, mm);
- mmput(mm);
+ mmput(mm, &mm_ref);
}
task_sig(m, task);
task_cap(m, task);
@@ -397,6 +398,7 @@ static int do_task_stat(struct seq_file *m, struct pid_namespace *ns,
int num_threads = 0;
int permitted;
struct mm_struct *mm;
+ MM_REF(mm_ref);
unsigned long long start_time;
unsigned long cmin_flt = 0, cmaj_flt = 0;
unsigned long min_flt = 0, maj_flt = 0;
@@ -409,7 +411,7 @@ static int do_task_stat(struct seq_file *m, struct pid_namespace *ns,
state = *get_task_state(task);
vsize = eip = esp = 0;
permitted = ptrace_may_access(task, PTRACE_MODE_READ_FSCREDS | PTRACE_MODE_NOAUDIT);
- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (mm) {
vsize = task_vsize(mm);
/*
@@ -562,7 +564,7 @@ static int do_task_stat(struct seq_file *m, struct pid_namespace *ns,

seq_putc(m, '\n');
if (mm)
- mmput(mm);
+ mmput(mm, &mm_ref);
return 0;
}

@@ -582,11 +584,12 @@ int proc_pid_statm(struct seq_file *m, struct pid_namespace *ns,
struct pid *pid, struct task_struct *task)
{
unsigned long size = 0, resident = 0, shared = 0, text = 0, data = 0;
- struct mm_struct *mm = get_task_mm(task);
+ MM_REF(mm_ref);
+ struct mm_struct *mm = get_task_mm(task, &mm_ref);

if (mm) {
size = task_statm(mm, &shared, &text, &data, &resident);
- mmput(mm);
+ mmput(mm, &mm_ref);
}
/*
* For quick read, open code by putting numbers directly
diff --git a/fs/proc/base.c b/fs/proc/base.c
index 87fd5bf07578..9c8bbfc0ab45 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -201,6 +201,7 @@ static ssize_t proc_pid_cmdline_read(struct file *file, char __user *buf,
{
struct task_struct *tsk;
struct mm_struct *mm;
+ MM_REF(mm_ref);
char *page;
unsigned long count = _count;
unsigned long arg_start, arg_end, env_start, env_end;
@@ -214,7 +215,7 @@ static ssize_t proc_pid_cmdline_read(struct file *file, char __user *buf,
tsk = get_proc_task(file_inode(file));
if (!tsk)
return -ESRCH;
- mm = get_task_mm(tsk);
+ mm = get_task_mm(tsk, &mm_ref);
put_task_struct(tsk);
if (!mm)
return 0;
@@ -389,7 +390,7 @@ static ssize_t proc_pid_cmdline_read(struct file *file, char __user *buf,
out_free_page:
free_page((unsigned long)page);
out_mmput:
- mmput(mm);
+ mmput(mm, &mm_ref);
if (rv > 0)
*pos += rv;
return rv;
@@ -784,34 +785,50 @@ static const struct file_operations proc_single_file_operations = {
};


-struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode)
+struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode, struct mm_ref *mm_ref)
{
struct task_struct *task = get_proc_task(inode);
struct mm_struct *mm = ERR_PTR(-ESRCH);
+ MM_REF(tmp_ref);

if (task) {
- mm = mm_access(task, mode | PTRACE_MODE_FSCREDS);
+ mm = mm_access(task, mode | PTRACE_MODE_FSCREDS, &tmp_ref);
put_task_struct(task);

if (!IS_ERR_OR_NULL(mm)) {
/* ensure this mm_struct can't be freed */
- mmgrab(mm);
+ mmgrab(mm, mm_ref);
/* but do not pin its memory */
- mmput(mm);
+ mmput(mm, &tmp_ref);
}
}

return mm;
}

+struct mem_private {
+ struct mm_struct *mm;
+ struct mm_ref mm_ref;
+};
+
static int __mem_open(struct inode *inode, struct file *file, unsigned int mode)
{
- struct mm_struct *mm = proc_mem_open(inode, mode);
+ struct mem_private *priv;
+ struct mm_struct *mm;

- if (IS_ERR(mm))
+ priv = kmalloc(sizeof(struct mem_private), GFP_KERNEL);
+ if (!priv)
+ return -ENOMEM;
+
+ INIT_MM_REF(&priv->mm_ref);
+ mm = proc_mem_open(inode, mode, &priv->mm_ref);
+ if (IS_ERR(mm)) {
+ kfree(priv);
return PTR_ERR(mm);
+ }

- file->private_data = mm;
+ priv->mm = mm;
+ file->private_data = priv;
return 0;
}

@@ -828,7 +845,9 @@ static int mem_open(struct inode *inode, struct file *file)
static ssize_t mem_rw(struct file *file, char __user *buf,
size_t count, loff_t *ppos, int write)
{
- struct mm_struct *mm = file->private_data;
+ struct mem_private *priv = file->private_data;
+ struct mm_struct *mm = priv->mm;
+ MM_REF(mm_ref);
unsigned long addr = *ppos;
ssize_t copied;
char *page;
@@ -842,7 +861,7 @@ static ssize_t mem_rw(struct file *file, char __user *buf,
return -ENOMEM;

copied = 0;
- if (!mmget_not_zero(mm))
+ if (!mmget_not_zero(mm, &mm_ref))
goto free;

/* Maybe we should limit FOLL_FORCE to actual ptrace users? */
@@ -877,7 +896,7 @@ static ssize_t mem_rw(struct file *file, char __user *buf,
}
*ppos = addr;

- mmput(mm);
+ mmput(mm, &mm_ref);
free:
free_page((unsigned long) page);
return copied;
@@ -913,9 +932,11 @@ loff_t mem_lseek(struct file *file, loff_t offset, int orig)

static int mem_release(struct inode *inode, struct file *file)
{
- struct mm_struct *mm = file->private_data;
+ struct mem_private *priv = file->private_data;
+ struct mm_struct *mm = priv->mm;
if (mm)
- mmdrop(mm);
+ mmdrop(mm, &priv->mm_ref);
+ kfree(priv);
return 0;
}

@@ -935,10 +956,12 @@ static int environ_open(struct inode *inode, struct file *file)
static ssize_t environ_read(struct file *file, char __user *buf,
size_t count, loff_t *ppos)
{
+ struct mem_private *priv = file->private_data;
char *page;
unsigned long src = *ppos;
int ret = 0;
- struct mm_struct *mm = file->private_data;
+ struct mm_struct *mm = priv->mm;
+ MM_REF(mm_ref);
unsigned long env_start, env_end;

/* Ensure the process spawned far enough to have an environment. */
@@ -950,7 +973,7 @@ static ssize_t environ_read(struct file *file, char __user *buf,
return -ENOMEM;

ret = 0;
- if (!mmget_not_zero(mm))
+ if (!mmget_not_zero(mm, &mm_ref))
goto free;

down_read(&mm->mmap_sem);
@@ -988,7 +1011,7 @@ static ssize_t environ_read(struct file *file, char __user *buf,
count -= retval;
}
*ppos = src;
- mmput(mm);
+ mmput(mm, &mm_ref);

free:
free_page((unsigned long) page);
@@ -1010,7 +1033,8 @@ static int auxv_open(struct inode *inode, struct file *file)
static ssize_t auxv_read(struct file *file, char __user *buf,
size_t count, loff_t *ppos)
{
- struct mm_struct *mm = file->private_data;
+ struct mem_private *priv = file->private_data;
+ struct mm_struct *mm = priv->mm;
unsigned int nwords = 0;

if (!mm)
@@ -1053,6 +1077,7 @@ static int __set_oom_adj(struct file *file, int oom_adj, bool legacy)
{
static DEFINE_MUTEX(oom_adj_mutex);
struct mm_struct *mm = NULL;
+ MM_REF(mm_ref);
struct task_struct *task;
int err = 0;

@@ -1093,7 +1118,7 @@ static int __set_oom_adj(struct file *file, int oom_adj, bool legacy)
if (p) {
if (atomic_read(&p->mm->mm_users) > 1) {
mm = p->mm;
- mmgrab(mm);
+ mmgrab(mm, &mm_ref);
}
task_unlock(p);
}
@@ -1129,7 +1154,7 @@ static int __set_oom_adj(struct file *file, int oom_adj, bool legacy)
task_unlock(p);
}
rcu_read_unlock();
- mmdrop(mm);
+ mmdrop(mm, &mm_ref);
}
err_unlock:
mutex_unlock(&oom_adj_mutex);
@@ -1875,6 +1900,7 @@ static int map_files_d_revalidate(struct dentry *dentry, unsigned int flags)
unsigned long vm_start, vm_end;
bool exact_vma_exists = false;
struct mm_struct *mm = NULL;
+ MM_REF(mm_ref);
struct task_struct *task;
const struct cred *cred;
struct inode *inode;
@@ -1888,7 +1914,7 @@ static int map_files_d_revalidate(struct dentry *dentry, unsigned int flags)
if (!task)
goto out_notask;

- mm = mm_access(task, PTRACE_MODE_READ_FSCREDS);
+ mm = mm_access(task, PTRACE_MODE_READ_FSCREDS, &mm_ref);
if (IS_ERR_OR_NULL(mm))
goto out;

@@ -1898,7 +1924,7 @@ static int map_files_d_revalidate(struct dentry *dentry, unsigned int flags)
up_read(&mm->mmap_sem);
}

- mmput(mm);
+ mmput(mm, &mm_ref);

if (exact_vma_exists) {
if (task_dumpable(task)) {
@@ -1933,6 +1959,7 @@ static int map_files_get_link(struct dentry *dentry, struct path *path)
struct vm_area_struct *vma;
struct task_struct *task;
struct mm_struct *mm;
+ MM_REF(mm_ref);
int rc;

rc = -ENOENT;
@@ -1940,7 +1967,7 @@ static int map_files_get_link(struct dentry *dentry, struct path *path)
if (!task)
goto out;

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
put_task_struct(task);
if (!mm)
goto out;
@@ -1960,7 +1987,7 @@ static int map_files_get_link(struct dentry *dentry, struct path *path)
up_read(&mm->mmap_sem);

out_mmput:
- mmput(mm);
+ mmput(mm, &mm_ref);
out:
return rc;
}
@@ -2034,6 +2061,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir,
struct task_struct *task;
int result;
struct mm_struct *mm;
+ MM_REF(mm_ref);

result = -ENOENT;
task = get_proc_task(dir);
@@ -2048,7 +2076,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir,
if (dname_to_vma_addr(dentry, &vm_start, &vm_end))
goto out_put_task;

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (!mm)
goto out_put_task;

@@ -2063,7 +2091,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir,

out_no_vma:
up_read(&mm->mmap_sem);
- mmput(mm);
+ mmput(mm, &mm_ref);
out_put_task:
put_task_struct(task);
out:
@@ -2082,6 +2110,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
struct vm_area_struct *vma;
struct task_struct *task;
struct mm_struct *mm;
+ MM_REF(mm_ref);
unsigned long nr_files, pos, i;
struct flex_array *fa = NULL;
struct map_files_info info;
@@ -2101,7 +2130,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
if (!dir_emit_dots(file, ctx))
goto out_put_task;

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (!mm)
goto out_put_task;
down_read(&mm->mmap_sem);
@@ -2132,7 +2161,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
if (fa)
flex_array_free(fa);
up_read(&mm->mmap_sem);
- mmput(mm);
+ mmput(mm, &mm_ref);
goto out_put_task;
}
for (i = 0, vma = mm->mmap, pos = 2; vma;
@@ -2164,7 +2193,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
}
if (fa)
flex_array_free(fa);
- mmput(mm);
+ mmput(mm, &mm_ref);

out_put_task:
put_task_struct(task);
@@ -2567,6 +2596,7 @@ static ssize_t proc_coredump_filter_read(struct file *file, char __user *buf,
{
struct task_struct *task = get_proc_task(file_inode(file));
struct mm_struct *mm;
+ MM_REF(mm_ref);
char buffer[PROC_NUMBUF];
size_t len;
int ret;
@@ -2575,12 +2605,12 @@ static ssize_t proc_coredump_filter_read(struct file *file, char __user *buf,
return -ESRCH;

ret = 0;
- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (mm) {
len = snprintf(buffer, sizeof(buffer), "%08lx\n",
((mm->flags & MMF_DUMP_FILTER_MASK) >>
MMF_DUMP_FILTER_SHIFT));
- mmput(mm);
+ mmput(mm, &mm_ref);
ret = simple_read_from_buffer(buf, count, ppos, buffer, len);
}

@@ -2596,6 +2626,7 @@ static ssize_t proc_coredump_filter_write(struct file *file,
{
struct task_struct *task;
struct mm_struct *mm;
+ MM_REF(mm_ref);
unsigned int val;
int ret;
int i;
@@ -2610,7 +2641,7 @@ static ssize_t proc_coredump_filter_write(struct file *file,
if (!task)
goto out_no_task;

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (!mm)
goto out_no_mm;
ret = 0;
@@ -2622,7 +2653,7 @@ static ssize_t proc_coredump_filter_write(struct file *file,
clear_bit(i + MMF_DUMP_FILTER_SHIFT, &mm->flags);
}

- mmput(mm);
+ mmput(mm, &mm_ref);
out_no_mm:
put_task_struct(task);
out_no_task:
diff --git a/fs/proc/internal.h b/fs/proc/internal.h
index 5378441ec1b7..9aed2e391b15 100644
--- a/fs/proc/internal.h
+++ b/fs/proc/internal.h
@@ -280,6 +280,8 @@ struct proc_maps_private {
struct inode *inode;
struct task_struct *task;
struct mm_struct *mm;
+ struct mm_ref mm_open_ref;
+ struct mm_ref mm_start_ref;
#ifdef CONFIG_MMU
struct vm_area_struct *tail_vma;
#endif
@@ -288,7 +290,7 @@ struct proc_maps_private {
#endif
};

-struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode);
+struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode, struct mm_ref *mm_ref);

extern const struct file_operations proc_pid_maps_operations;
extern const struct file_operations proc_tid_maps_operations;
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index c71975293dc8..06ed5d67dd84 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -133,7 +133,7 @@ static void vma_stop(struct proc_maps_private *priv)

release_task_mempolicy(priv);
up_read(&mm->mmap_sem);
- mmput(mm);
+ mmput(mm, &priv->mm_start_ref);
}

static struct vm_area_struct *
@@ -167,7 +167,7 @@ static void *m_start(struct seq_file *m, loff_t *ppos)
return ERR_PTR(-ESRCH);

mm = priv->mm;
- if (!mm || !mmget_not_zero(mm))
+ if (!mm || !mmget_not_zero(mm, &priv->mm_start_ref))
return NULL;

down_read(&mm->mmap_sem);
@@ -232,7 +232,9 @@ static int proc_maps_open(struct inode *inode, struct file *file,
return -ENOMEM;

priv->inode = inode;
- priv->mm = proc_mem_open(inode, PTRACE_MODE_READ);
+ INIT_MM_REF(&priv->mm_open_ref);
+ INIT_MM_REF(&priv->mm_start_ref);
+ priv->mm = proc_mem_open(inode, PTRACE_MODE_READ, &priv->mm_open_ref);
if (IS_ERR(priv->mm)) {
int err = PTR_ERR(priv->mm);

@@ -249,7 +251,7 @@ static int proc_map_release(struct inode *inode, struct file *file)
struct proc_maps_private *priv = seq->private;

if (priv->mm)
- mmdrop(priv->mm);
+ mmdrop(priv->mm, &priv->mm_open_ref);

return seq_release_private(inode, file);
}
@@ -997,6 +999,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
struct task_struct *task;
char buffer[PROC_NUMBUF];
struct mm_struct *mm;
+ MM_REF(mm_ref);
struct vm_area_struct *vma;
enum clear_refs_types type;
int itype;
@@ -1017,7 +1020,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
task = get_proc_task(file_inode(file));
if (!task)
return -ESRCH;
- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (mm) {
struct clear_refs_private cp = {
.type = type,
@@ -1069,7 +1072,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
flush_tlb_mm(mm);
up_read(&mm->mmap_sem);
out_mm:
- mmput(mm);
+ mmput(mm, &mm_ref);
}
put_task_struct(task);

@@ -1340,10 +1343,17 @@ static int pagemap_hugetlb_range(pte_t *ptep, unsigned long hmask,
* determine which areas of memory are actually mapped and llseek to
* skip over unmapped regions.
*/
+struct pagemap_private {
+ struct mm_struct *mm;
+ struct mm_ref mm_ref;
+};
+
static ssize_t pagemap_read(struct file *file, char __user *buf,
size_t count, loff_t *ppos)
{
- struct mm_struct *mm = file->private_data;
+ struct pagemap_private *priv = file->private_data;
+ struct mm_struct *mm = priv->mm;
+ MM_REF(mm_ref);
struct pagemapread pm;
struct mm_walk pagemap_walk = {};
unsigned long src;
@@ -1352,7 +1362,7 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
unsigned long end_vaddr;
int ret = 0, copied = 0;

- if (!mm || !mmget_not_zero(mm))
+ if (!mm || !mmget_not_zero(mm, &mm_ref))
goto out;

ret = -EINVAL;
@@ -1427,28 +1437,40 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
out_free:
kfree(pm.buffer);
out_mm:
- mmput(mm);
+ mmput(mm, &mm_ref);
out:
return ret;
}

static int pagemap_open(struct inode *inode, struct file *file)
{
+ struct pagemap_private *priv;
struct mm_struct *mm;

- mm = proc_mem_open(inode, PTRACE_MODE_READ);
- if (IS_ERR(mm))
+ priv = kmalloc(sizeof(struct pagemap_private), GFP_KERNEL);
+ if (!priv)
+ return -ENOMEM;
+
+ mm = proc_mem_open(inode, PTRACE_MODE_READ, &priv->mm_ref);
+ if (IS_ERR(mm)) {
+ kfree(priv);
return PTR_ERR(mm);
- file->private_data = mm;
+ }
+
+ priv->mm = mm;
+ file->private_data = priv;
return 0;
}

static int pagemap_release(struct inode *inode, struct file *file)
{
- struct mm_struct *mm = file->private_data;
+ struct pagemap_private *priv = file->private_data;
+ struct mm_struct *mm = priv->mm;

if (mm)
- mmdrop(mm);
+ mmdrop(mm, &priv->mm_ref);
+
+ kfree(priv);
return 0;
}

diff --git a/include/linux/binfmts.h b/include/linux/binfmts.h
index 1303b570b18c..8bee41838bd5 100644
--- a/include/linux/binfmts.h
+++ b/include/linux/binfmts.h
@@ -21,6 +21,7 @@ struct linux_binprm {
struct page *page[MAX_ARG_PAGES];
#endif
struct mm_struct *mm;
+ struct mm_ref mm_ref;
unsigned long p; /* current top of mem */
unsigned int
cred_prepared:1,/* true if creds already prepared (multiple
diff --git a/include/linux/init_task.h b/include/linux/init_task.h
index 325f649d77ff..02c9ecf243d1 100644
--- a/include/linux/init_task.h
+++ b/include/linux/init_task.h
@@ -211,6 +211,7 @@ extern struct task_group root_task_group;
.cpus_allowed = CPU_MASK_ALL, \
.nr_cpus_allowed= NR_CPUS, \
.mm = NULL, \
+ .mm_ref = MM_REF_INIT(tsk.mm_ref), \
.active_mm = &init_mm, \
.restart_block = { \
.fn = do_no_restart_syscall, \
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 01c0b9cc3915..635d4a84f03b 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -174,6 +174,7 @@ struct kvm_async_pf {
struct list_head queue;
struct kvm_vcpu *vcpu;
struct mm_struct *mm;
+ struct mm_ref mm_ref;
gva_t gva;
unsigned long addr;
struct kvm_arch_async_pf arch;
@@ -376,6 +377,7 @@ struct kvm {
spinlock_t mmu_lock;
struct mutex slots_lock;
struct mm_struct *mm; /* userspace tied to this vm */
+ struct mm_ref mm_ref;
struct kvm_memslots *memslots[KVM_ADDRESS_SPACE_NUM];
struct srcu_struct srcu;
struct srcu_struct irq_srcu;
@@ -424,6 +426,7 @@ struct kvm {

#if defined(CONFIG_MMU_NOTIFIER) && defined(KVM_ARCH_WANT_MMU_NOTIFIER)
struct mmu_notifier mmu_notifier;
+ struct mm_ref mmu_notifier_ref;
unsigned long mmu_notifier_seq;
long mmu_notifier_count;
#endif
diff --git a/include/linux/mm_ref.h b/include/linux/mm_ref.h
new file mode 100644
index 000000000000..0de29bd64542
--- /dev/null
+++ b/include/linux/mm_ref.h
@@ -0,0 +1,48 @@
+#ifndef LINUX_MM_REF_H
+#define LINUX_MM_REF_H
+
+#include <linux/mm_types.h>
+#include <linux/mm_ref_types.h>
+
+struct mm_struct;
+
+extern void INIT_MM_REF(struct mm_ref *ref);
+
+extern void _get_mm_ref(struct mm_struct *mm, struct list_head *list,
+ struct mm_ref *ref);
+extern void _put_mm_ref(struct mm_struct *mm, struct list_head *list,
+ struct mm_ref *ref);
+extern void _move_mm_ref(struct mm_struct *mm, struct list_head *list,
+ struct mm_ref *old_ref, struct mm_ref *new_ref);
+
+static inline void get_mm_ref(struct mm_struct *mm, struct mm_ref *ref)
+{
+ _get_mm_ref(mm, &mm->mm_count_list, ref);
+}
+
+static inline void put_mm_ref(struct mm_struct *mm, struct mm_ref *ref)
+{
+ _put_mm_ref(mm, &mm->mm_count_list, ref);
+}
+
+static inline void move_mm_ref(struct mm_struct *mm, struct mm_ref *old_ref, struct mm_ref *new_ref)
+{
+ _move_mm_ref(mm, &mm->mm_count_list, old_ref, new_ref);
+}
+
+static inline void get_mm_users_ref(struct mm_struct *mm, struct mm_ref *ref)
+{
+ _get_mm_ref(mm, &mm->mm_users_list, ref);
+}
+
+static inline void put_mm_users_ref(struct mm_struct *mm, struct mm_ref *ref)
+{
+ _put_mm_ref(mm, &mm->mm_users_list, ref);
+}
+
+static inline void move_mm_users_ref(struct mm_struct *mm, struct mm_ref *old_ref, struct mm_ref *new_ref)
+{
+ _move_mm_ref(mm, &mm->mm_users_list, old_ref, new_ref);
+}
+
+#endif
diff --git a/include/linux/mm_ref_types.h b/include/linux/mm_ref_types.h
new file mode 100644
index 000000000000..5c45995688bd
--- /dev/null
+++ b/include/linux/mm_ref_types.h
@@ -0,0 +1,41 @@
+#ifndef LINUX_MM_REF_TYPES_H
+#define LINUX_MM_REF_TYPES_H
+
+#include <linux/list.h>
+#include <linux/stacktrace.h>
+
+#define NR_MM_REF_STACK_ENTRIES 10
+
+enum mm_ref_state {
+ /*
+ * Pick 0 as uninitialized so we have a chance at catching
+ * uninitialized references by noticing that they are zero.
+ *
+ * The rest are arbitrary 32-bit values that garbage is unlikely to match.
+ */
+ MM_REF_UNINITIALIZED = 0,
+ MM_REF_INITIALIZED = 0x28076894UL,
+ MM_REF_ACTIVE = 0xdaf46189UL,
+ MM_REF_INACTIVE = 0xf5358bafUL,
+};
+
+struct mm_ref {
+ /*
+ * See ->mm_users_list/->mm_count_list in struct mm_struct.
+ * Access is protected by ->mm_refs_lock.
+ */
+ struct list_head list_entry;
+
+ enum mm_ref_state state;
+ int pid;
+ struct stack_trace trace;
+ unsigned long trace_entries[NR_MM_REF_STACK_ENTRIES];
+};
+
+#define MM_REF_INIT(name) \
+ { LIST_HEAD_INIT(name.list_entry), MM_REF_INITIALIZED, }
+
+#define MM_REF(name) \
+ struct mm_ref name = MM_REF_INIT(name)
+
+#endif
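The non-zero, arbitrary-looking values in `enum mm_ref_state` let a checker tell a zero-filled struct (for example one obtained from `kzalloc()` without a following `INIT_MM_REF()`) apart from a valid ref, and make it unlikely that stray memory matches a legal state. Below is a hypothetical userspace classifier over those same constants; `check_before_get()` and `struct holder` are illustrative names only. It assumes that a ref may be taken again after it has been put, which the futex retry paths in this patch rely on (`put_futex_key()` followed by `goto retry`). The constants are `#define`d as `unsigned int` because ISO C requires enumeration constants to fit in an `int`, and `0xdaf46189` does not (GCC accepts it as an extension).

```c
#include <assert.h>
#include <stdlib.h>

/* Same values as enum mm_ref_state in <linux/mm_ref_types.h>. */
#define REF_UNINITIALIZED 0x00000000u
#define REF_INITIALIZED   0x28076894u
#define REF_ACTIVE        0xdaf46189u
#define REF_INACTIVE      0xf5358bafu

enum verdict { GET_OK, GET_UNINITIALIZED, GET_ALREADY_ACTIVE, GET_CORRUPT };

/* Hypothetical check run before recording a new reference. */
static enum verdict check_before_get(unsigned int state)
{
	switch (state) {
	case REF_INITIALIZED:
	case REF_INACTIVE:	/* reuse after put, as in the futex retry loops */
		return GET_OK;
	case REF_UNINITIALIZED:
		return GET_UNINITIALIZED;
	case REF_ACTIVE:
		return GET_ALREADY_ACTIVE;
	default:
		return GET_CORRUPT;
	}
}

/* Hypothetical container whose embedded ref nobody initialized. */
struct holder {
	void *payload;
	unsigned int ref_state;
};

static enum verdict check_zeroed_holder(void)
{
	struct holder *h = calloc(1, sizeof(*h));	/* userspace kzalloc() */
	enum verdict v;

	assert(h != NULL);
	v = check_before_get(h->ref_state);
	free(h);
	return v;
}
```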
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 4a8acedf4b7d..520cde63305d 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -13,6 +13,7 @@
#include <linux/uprobes.h>
#include <linux/page-flags-layout.h>
#include <linux/workqueue.h>
+#include <linux/mm_ref_types.h>
#include <asm/page.h>
#include <asm/mmu.h>

@@ -407,8 +408,14 @@ struct mm_struct {
unsigned long task_size; /* size of task vm space */
unsigned long highest_vm_end; /* highest vma end address */
pgd_t * pgd;
+
+ spinlock_t mm_refs_lock; /* Protects mm_users_list and mm_count_list */
atomic_t mm_users; /* How many users with user space? */
+ struct list_head mm_users_list;
+ struct mm_ref mm_users_ref;
atomic_t mm_count; /* How many references to "struct mm_struct" (users count as 1) */
+ struct list_head mm_count_list;
+
atomic_long_t nr_ptes; /* PTE page table pages */
#if CONFIG_PGTABLE_LEVELS > 2
atomic_long_t nr_pmds; /* PMD page table pages */
@@ -516,6 +523,7 @@ struct mm_struct {
atomic_long_t hugetlb_usage;
#endif
struct work_struct async_put_work;
+ struct mm_ref async_put_ref;
};

static inline void mm_init_cpumask(struct mm_struct *mm)
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index a1a210d59961..e67867bec2d1 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -201,13 +201,13 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
}

extern int mmu_notifier_register(struct mmu_notifier *mn,
- struct mm_struct *mm);
+ struct mm_struct *mm, struct mm_ref *mm_ref);
extern int __mmu_notifier_register(struct mmu_notifier *mn,
- struct mm_struct *mm);
+ struct mm_struct *mm, struct mm_ref *mm_ref);
extern void mmu_notifier_unregister(struct mmu_notifier *mn,
- struct mm_struct *mm);
+ struct mm_struct *mm, struct mm_ref *mm_ref);
extern void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
- struct mm_struct *mm);
+ struct mm_struct *mm, struct mm_ref *mm_ref);
extern void __mmu_notifier_mm_destroy(struct mm_struct *mm);
extern void __mmu_notifier_release(struct mm_struct *mm);
extern int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
diff --git a/include/linux/sched.h b/include/linux/sched.h
index 2ca3e15dad3b..293c64a15dfa 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -25,6 +25,7 @@ struct sched_param {
#include <linux/errno.h>
#include <linux/nodemask.h>
#include <linux/mm_types.h>
+#include <linux/mm_ref.h>
#include <linux/preempt.h>

#include <asm/page.h>
@@ -808,6 +809,7 @@ struct signal_struct {
* Only settable by CAP_SYS_RESOURCE. */
struct mm_struct *oom_mm; /* recorded mm when the thread group got
* killed by the oom killer */
+ struct mm_ref oom_mm_ref;

struct mutex cred_guard_mutex; /* guard against foreign influences on
* credential calculations
@@ -1546,7 +1548,16 @@ struct task_struct {
struct rb_node pushable_dl_tasks;
#endif

+ /*
+ * ->mm and ->active_mm share the mm_ref. Not ideal IMHO, but that's
+ * how it's done. For kernel threads, ->mm == NULL, and for user
+ * threads, ->mm == ->active_mm, so we only need one reference.
+ *
+ * See <Documentation/vm/active_mm.txt> for more information.
+ */
struct mm_struct *mm, *active_mm;
+ struct mm_ref mm_ref;
+
/* per-thread vma caching */
u32 vmacache_seqnum;
struct vm_area_struct *vmacache[VMACACHE_SIZE];
@@ -2639,6 +2650,7 @@ extern union thread_union init_thread_union;
extern struct task_struct init_task;

extern struct mm_struct init_mm;
+extern struct mm_ref init_mm_ref;

extern struct pid_namespace init_pid_ns;

@@ -2870,17 +2882,19 @@ static inline unsigned long sigsp(unsigned long sp, struct ksignal *ksig)
/*
* Routines for handling mm_structs
*/
-extern struct mm_struct * mm_alloc(void);
+extern struct mm_struct * mm_alloc(struct mm_ref *ref);

-static inline void mmgrab(struct mm_struct *mm)
+static inline void mmgrab(struct mm_struct *mm, struct mm_ref *ref)
{
atomic_inc(&mm->mm_count);
+ get_mm_ref(mm, ref);
}

/* mmdrop drops the mm and the page tables */
extern void __mmdrop(struct mm_struct *);
-static inline void mmdrop(struct mm_struct *mm)
+static inline void mmdrop(struct mm_struct *mm, struct mm_ref *ref)
{
+ put_mm_ref(mm, ref);
if (unlikely(atomic_dec_and_test(&mm->mm_count)))
__mmdrop(mm);
}
@@ -2891,41 +2905,47 @@ static inline void mmdrop_async_fn(struct work_struct *work)
__mmdrop(mm);
}

-static inline void mmdrop_async(struct mm_struct *mm)
+static inline void mmdrop_async(struct mm_struct *mm, struct mm_ref *ref)
{
+ put_mm_ref(mm, ref);
if (unlikely(atomic_dec_and_test(&mm->mm_count))) {
INIT_WORK(&mm->async_put_work, mmdrop_async_fn);
schedule_work(&mm->async_put_work);
}
}

-static inline void mmget(struct mm_struct *mm)
+static inline void mmget(struct mm_struct *mm, struct mm_ref *ref)
{
atomic_inc(&mm->mm_users);
+ get_mm_users_ref(mm, ref);
}

-static inline bool mmget_not_zero(struct mm_struct *mm)
+static inline bool mmget_not_zero(struct mm_struct *mm, struct mm_ref *ref)
{
- return atomic_inc_not_zero(&mm->mm_users);
+ if (!atomic_inc_not_zero(&mm->mm_users))
+ return false;
+ get_mm_users_ref(mm, ref);
+
+ return true;
}

/* mmput gets rid of the mappings and all user-space */
-extern void mmput(struct mm_struct *);
+extern void mmput(struct mm_struct *, struct mm_ref *);
#ifdef CONFIG_MMU
/* same as above but performs the slow path from the async context. Can
* be called from the atomic context as well
*/
-extern void mmput_async(struct mm_struct *);
+extern void mmput_async(struct mm_struct *, struct mm_ref *);
#endif

/* Grab a reference to a task's mm, if it is not already going away */
-extern struct mm_struct *get_task_mm(struct task_struct *task);
+extern struct mm_struct *get_task_mm(struct task_struct *task, struct mm_ref *mm_ref);
/*
* Grab a reference to a task's mm, if it is not already going away
* and ptrace_may_access with the mode parameter passed to it
* succeeds.
*/
-extern struct mm_struct *mm_access(struct task_struct *task, unsigned int mode);
+extern struct mm_struct *mm_access(struct task_struct *task, unsigned int mode, struct mm_ref *mm_ref);
/* Remove the current tasks stale references to the old mm_struct */
extern void mm_release(struct task_struct *, struct mm_struct *);

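`mmget_not_zero()` records the reference only when `atomic_inc_not_zero()` actually took one. On failure there is nothing to put later, so adding a list entry anyway would itself look like a leak. The following is a minimal C11 model of that rule, under stated assumptions: `inc_not_zero()` is a compare-and-swap loop standing in for the kernel's `atomic_inc_not_zero()`, and the plain `tracked` counter stands in for `->mm_users_list`, which the real code protects with `->mm_refs_lock`. `struct obj` and the helper names are hypothetical.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical model: users mirrors ->mm_users, tracked mirrors the
 * number of entries on ->mm_users_list. */
struct obj {
	atomic_int users;
	int tracked;
};

static void obj_init(struct obj *o, int users)
{
	atomic_init(&o->users, users);
	o->tracked = users;
}

/* Userspace stand-in for atomic_inc_not_zero(). */
static bool inc_not_zero(atomic_int *v)
{
	int old = atomic_load(v);

	while (old != 0) {
		/* On failure, old is reloaded with the current value. */
		if (atomic_compare_exchange_weak(v, &old, old + 1))
			return true;
	}
	return false;
}

/* Same shape as mmget_not_zero(): record a ref only on success. */
static bool obj_get_not_zero(struct obj *o)
{
	if (!inc_not_zero(&o->users))
		return false;
	o->tracked++;	/* stands in for get_mm_users_ref() */
	return true;
}
```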
diff --git a/kernel/cpuset.c b/kernel/cpuset.c
index 29f815d2ef7e..66c5778f4052 100644
--- a/kernel/cpuset.c
+++ b/kernel/cpuset.c
@@ -994,6 +994,7 @@ static int update_cpumask(struct cpuset *cs, struct cpuset *trialcs,
struct cpuset_migrate_mm_work {
struct work_struct work;
struct mm_struct *mm;
+ struct mm_ref mm_ref;
nodemask_t from;
nodemask_t to;
};
@@ -1005,24 +1006,25 @@ static void cpuset_migrate_mm_workfn(struct work_struct *work)

/* on a wq worker, no need to worry about %current's mems_allowed */
do_migrate_pages(mwork->mm, &mwork->from, &mwork->to, MPOL_MF_MOVE_ALL);
- mmput(mwork->mm);
+ mmput(mwork->mm, &mwork->mm_ref);
kfree(mwork);
}

static void cpuset_migrate_mm(struct mm_struct *mm, const nodemask_t *from,
- const nodemask_t *to)
+ const nodemask_t *to, struct mm_ref *mm_ref)
{
struct cpuset_migrate_mm_work *mwork;

mwork = kzalloc(sizeof(*mwork), GFP_KERNEL);
if (mwork) {
mwork->mm = mm;
+ move_mm_users_ref(mm, mm_ref, &mwork->mm_ref);
mwork->from = *from;
mwork->to = *to;
INIT_WORK(&mwork->work, cpuset_migrate_mm_workfn);
queue_work(cpuset_migrate_mm_wq, &mwork->work);
} else {
- mmput(mm);
+ mmput(mm, mm_ref);
}
}

@@ -1107,11 +1109,12 @@ static void update_tasks_nodemask(struct cpuset *cs)
css_task_iter_start(&cs->css, &it);
while ((task = css_task_iter_next(&it))) {
struct mm_struct *mm;
+ MM_REF(mm_ref);
bool migrate;

cpuset_change_task_nodemask(task, &newmems);

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
if (!mm)
continue;

@@ -1119,9 +1122,9 @@ static void update_tasks_nodemask(struct cpuset *cs)

mpol_rebind_mm(mm, &cs->mems_allowed);
if (migrate)
- cpuset_migrate_mm(mm, &cs->old_mems_allowed, &newmems);
+ cpuset_migrate_mm(mm, &cs->old_mems_allowed, &newmems, &mm_ref);
else
- mmput(mm);
+ mmput(mm, &mm_ref);
}
css_task_iter_end(&it);

@@ -1556,7 +1559,8 @@ static void cpuset_attach(struct cgroup_taskset *tset)
*/
cpuset_attach_nodemask_to = cs->effective_mems;
cgroup_taskset_for_each_leader(leader, css, tset) {
- struct mm_struct *mm = get_task_mm(leader);
+ MM_REF(mm_ref);
+ struct mm_struct *mm = get_task_mm(leader, &mm_ref);

if (mm) {
mpol_rebind_mm(mm, &cpuset_attach_nodemask_to);
@@ -1571,9 +1575,9 @@ static void cpuset_attach(struct cgroup_taskset *tset)
*/
if (is_memory_migrate(cs))
cpuset_migrate_mm(mm, &oldcs->old_mems_allowed,
- &cpuset_attach_nodemask_to);
+ &cpuset_attach_nodemask_to, &mm_ref);
else
- mmput(mm);
+ mmput(mm, &mm_ref);
}
}

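`cpuset_migrate_mm()` shows the ownership-transfer case: the caller's `mm_users` reference is moved into the work item with `move_mm_users_ref()`, so the counter is untouched and only the holder changes; the worker later puts `&mwork->mm_ref`. The sketch below models that handoff in userspace with a state-only ref (no list) to keep it short. `struct mm_model`, `move_ref()`, `hand_off()` and `worker()` are hypothetical. Note that `kzalloc()`, like `calloc()` here, leaves the embedded ref zeroed (`MM_REF_UNINITIALIZED`). Whether the real `_move_mm_ref()` accepts such a target is not visible in this excerpt, so the model initializes it explicitly.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>

#define REF_INITIALIZED 0x28076894u
#define REF_ACTIVE      0xdaf46189u
#define REF_INACTIVE    0xf5358bafu

struct ref { unsigned int state; };

/* Models ->mm_users together with the number of active refs on its list. */
struct mm_model { int users; int active_refs; };

static bool get_ref(struct mm_model *mm, struct ref *r)
{
	if (r->state != REF_INITIALIZED && r->state != REF_INACTIVE)
		return false;
	r->state = REF_ACTIVE;
	mm->users++;
	mm->active_refs++;
	return true;
}

static bool put_ref(struct mm_model *mm, struct ref *r)
{
	if (r->state != REF_ACTIVE)
		return false;	/* not an active reference: refuse */
	r->state = REF_INACTIVE;
	mm->users--;
	mm->active_refs--;
	return true;
}

/* Transfer an active reference: neither counter changes, only the holder. */
static bool move_ref(struct ref *from, struct ref *to)
{
	if (from->state != REF_ACTIVE ||
	    (to->state != REF_INITIALIZED && to->state != REF_INACTIVE))
		return false;
	from->state = REF_INACTIVE;
	to->state = REF_ACTIVE;
	return true;
}

struct migrate_work {
	struct mm_model *mm;
	struct ref ref;
};

/* Hypothetical analogue of cpuset_migrate_mm(): queue the work or drop the ref. */
static struct migrate_work *hand_off(struct mm_model *mm, struct ref *caller_ref)
{
	struct migrate_work *w = calloc(1, sizeof(*w));

	if (!w) {
		put_ref(mm, caller_ref);	/* like the mmput() in the else branch */
		return NULL;
	}
	w->mm = mm;
	w->ref.state = REF_INITIALIZED;	/* calloc() left it at 0 */
	move_ref(caller_ref, &w->ref);
	return w;
}

/* Hypothetical analogue of cpuset_migrate_mm_workfn(). */
static void worker(struct migrate_work *w)
{
	put_ref(w->mm, &w->ref);
	free(w);
}
```

After the handoff, a stray extra put by the caller is refused instead of silently consuming the worker's reference, which is the class of bug this patch is meant to surface early.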
diff --git a/kernel/events/core.c b/kernel/events/core.c
index 02c8421f8c01..2909d6db3b7a 100644
--- a/kernel/events/core.c
+++ b/kernel/events/core.c
@@ -7965,6 +7965,7 @@ static void perf_event_addr_filters_apply(struct perf_event *event)
struct task_struct *task = READ_ONCE(event->ctx->task);
struct perf_addr_filter *filter;
struct mm_struct *mm = NULL;
+ MM_REF(mm_ref);
unsigned int count = 0;
unsigned long flags;

@@ -7975,7 +7976,7 @@ static void perf_event_addr_filters_apply(struct perf_event *event)
if (task == TASK_TOMBSTONE)
return;

- mm = get_task_mm(event->ctx->task);
+ mm = get_task_mm(event->ctx->task, &mm_ref);
if (!mm)
goto restart;

@@ -8001,7 +8002,7 @@ static void perf_event_addr_filters_apply(struct perf_event *event)

up_read(&mm->mmap_sem);

- mmput(mm);
+ mmput(mm, &mm_ref);

restart:
perf_event_stop(event, 1);
diff --git a/kernel/exit.c b/kernel/exit.c
index b12753840050..d367ef9bcfe6 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -462,6 +462,7 @@ void mm_update_next_owner(struct mm_struct *mm)
static void exit_mm(struct task_struct *tsk)
{
struct mm_struct *mm = tsk->mm;
+ MM_REF(mm_ref);
struct core_state *core_state;

mm_release(tsk, mm);
@@ -500,7 +501,7 @@ static void exit_mm(struct task_struct *tsk)
__set_task_state(tsk, TASK_RUNNING);
down_read(&mm->mmap_sem);
}
- mmgrab(mm);
+ mmgrab(mm, &mm_ref);
BUG_ON(mm != tsk->active_mm);
/* more a memory barrier than a real lock */
task_lock(tsk);
@@ -509,7 +510,7 @@ static void exit_mm(struct task_struct *tsk)
enter_lazy_tlb(mm, current);
task_unlock(tsk);
mm_update_next_owner(mm);
- mmput(mm);
+ mmput(mm, &mm_ref);
if (test_thread_flag(TIF_MEMDIE))
exit_oom_victim();
}
diff --git a/kernel/fork.c b/kernel/fork.c
index f9c32dc6ccbc..a431a52375d7 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -367,7 +367,7 @@ static inline void free_signal_struct(struct signal_struct *sig)
* pgd_dtor so postpone it to the async context
*/
if (sig->oom_mm)
- mmdrop_async(sig->oom_mm);
+ mmdrop_async(sig->oom_mm, &sig->oom_mm_ref);
kmem_cache_free(signal_cachep, sig);
}

@@ -745,13 +745,22 @@ static void mm_init_owner(struct mm_struct *mm, struct task_struct *p)
#endif
}

-static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p)
+static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, struct mm_ref *mm_ref)
{
mm->mmap = NULL;
mm->mm_rb = RB_ROOT;
mm->vmacache_seqnum = 0;
+
atomic_set(&mm->mm_users, 1);
+ INIT_LIST_HEAD(&mm->mm_users_list);
+ INIT_MM_REF(&mm->mm_users_ref);
+
atomic_set(&mm->mm_count, 1);
+ INIT_LIST_HEAD(&mm->mm_count_list);
+
+ get_mm_ref(mm, &mm->mm_users_ref);
+ get_mm_users_ref(mm, mm_ref);
+
init_rwsem(&mm->mmap_sem);
INIT_LIST_HEAD(&mm->mmlist);
mm->core_state = NULL;
@@ -821,7 +830,7 @@ static void check_mm(struct mm_struct *mm)
/*
* Allocate and initialize an mm_struct.
*/
-struct mm_struct *mm_alloc(void)
+struct mm_struct *mm_alloc(struct mm_ref *ref)
{
struct mm_struct *mm;

@@ -830,7 +839,7 @@ struct mm_struct *mm_alloc(void)
return NULL;

memset(mm, 0, sizeof(*mm));
- return mm_init(mm, current);
+ return mm_init(mm, current, ref);
}

/*
@@ -868,16 +877,17 @@ static inline void __mmput(struct mm_struct *mm)
if (mm->binfmt)
module_put(mm->binfmt->module);
set_bit(MMF_OOM_SKIP, &mm->flags);
- mmdrop(mm);
+ mmdrop(mm, &mm->mm_users_ref);
}

/*
* Decrement the use count and release all resources for an mm.
*/
-void mmput(struct mm_struct *mm)
+void mmput(struct mm_struct *mm, struct mm_ref *ref)
{
might_sleep();

+ put_mm_users_ref(mm, ref);
if (atomic_dec_and_test(&mm->mm_users))
__mmput(mm);
}
@@ -890,8 +900,9 @@ static void mmput_async_fn(struct work_struct *work)
__mmput(mm);
}

-void mmput_async(struct mm_struct *mm)
+void mmput_async(struct mm_struct *mm, struct mm_ref *ref)
{
+ put_mm_users_ref(mm, ref);
if (atomic_dec_and_test(&mm->mm_users)) {
INIT_WORK(&mm->async_put_work, mmput_async_fn);
schedule_work(&mm->async_put_work);
@@ -979,7 +990,7 @@ EXPORT_SYMBOL(get_task_exe_file);
* bumping up the use count. User must release the mm via mmput()
* after use. Typically used by /proc and ptrace.
*/
-struct mm_struct *get_task_mm(struct task_struct *task)
+struct mm_struct *get_task_mm(struct task_struct *task, struct mm_ref *mm_ref)
{
struct mm_struct *mm;

@@ -989,14 +1000,14 @@ struct mm_struct *get_task_mm(struct task_struct *task)
if (task->flags & PF_KTHREAD)
mm = NULL;
else
- mmget(mm);
+ mmget(mm, mm_ref);
}
task_unlock(task);
return mm;
}
EXPORT_SYMBOL_GPL(get_task_mm);

-struct mm_struct *mm_access(struct task_struct *task, unsigned int mode)
+struct mm_struct *mm_access(struct task_struct *task, unsigned int mode, struct mm_ref *mm_ref)
{
struct mm_struct *mm;
int err;
@@ -1005,10 +1016,10 @@ struct mm_struct *mm_access(struct task_struct *task, unsigned int mode)
if (err)
return ERR_PTR(err);

- mm = get_task_mm(task);
+ mm = get_task_mm(task, mm_ref);
if (mm && mm != current->mm &&
!ptrace_may_access(task, mode)) {
- mmput(mm);
+ mmput(mm, mm_ref);
mm = ERR_PTR(-EACCES);
}
mutex_unlock(&task->signal->cred_guard_mutex);
@@ -1115,7 +1126,7 @@ void mm_release(struct task_struct *tsk, struct mm_struct *mm)
* Allocate a new mm structure and copy contents from the
* mm structure of the passed in task structure.
*/
-static struct mm_struct *dup_mm(struct task_struct *tsk)
+static struct mm_struct *dup_mm(struct task_struct *tsk, struct mm_ref *ref)
{
struct mm_struct *mm, *oldmm = current->mm;
int err;
@@ -1126,7 +1137,7 @@ static struct mm_struct *dup_mm(struct task_struct *tsk)

memcpy(mm, oldmm, sizeof(*mm));

- if (!mm_init(mm, tsk))
+ if (!mm_init(mm, tsk, ref))
goto fail_nomem;

err = dup_mmap(mm, oldmm);
@@ -1144,7 +1155,7 @@ static struct mm_struct *dup_mm(struct task_struct *tsk)
free_pt:
/* don't put binfmt in mmput, we haven't got module yet */
mm->binfmt = NULL;
- mmput(mm);
+ mmput(mm, ref);

fail_nomem:
return NULL;
@@ -1163,6 +1174,7 @@ static int copy_mm(unsigned long clone_flags, struct task_struct *tsk)

tsk->mm = NULL;
tsk->active_mm = NULL;
+ INIT_MM_REF(&tsk->mm_ref);

/*
* Are we cloning a kernel thread?
@@ -1177,13 +1189,13 @@ static int copy_mm(unsigned long clone_flags, struct task_struct *tsk)
vmacache_flush(tsk);

if (clone_flags & CLONE_VM) {
- mmget(oldmm);
+ mmget(oldmm, &tsk->mm_ref);
mm = oldmm;
goto good_mm;
}

retval = -ENOMEM;
- mm = dup_mm(tsk);
+ mm = dup_mm(tsk, &tsk->mm_ref);
if (!mm)
goto fail_nomem;

@@ -1360,6 +1372,9 @@ static int copy_signal(unsigned long clone_flags, struct task_struct *tsk)
sig->oom_score_adj = current->signal->oom_score_adj;
sig->oom_score_adj_min = current->signal->oom_score_adj_min;

+ sig->oom_mm = NULL;
+ INIT_MM_REF(&sig->oom_mm_ref);
+
sig->has_child_subreaper = current->signal->has_child_subreaper ||
current->signal->is_child_subreaper;

@@ -1839,7 +1854,7 @@ static __latent_entropy struct task_struct *copy_process(
exit_task_namespaces(p);
bad_fork_cleanup_mm:
if (p->mm)
- mmput(p->mm);
+ mmput(p->mm, &p->mm_ref);
bad_fork_cleanup_signal:
if (!(clone_flags & CLONE_THREAD))
free_signal_struct(p->signal);
diff --git a/kernel/futex.c b/kernel/futex.c
index cbe6056c17c1..3a279ee2166b 100644
--- a/kernel/futex.c
+++ b/kernel/futex.c
@@ -240,6 +240,7 @@ struct futex_q {
struct task_struct *task;
spinlock_t *lock_ptr;
union futex_key key;
+ struct mm_ref mm_ref;
struct futex_pi_state *pi_state;
struct rt_mutex_waiter *rt_waiter;
union futex_key *requeue_pi_key;
@@ -249,6 +250,7 @@ struct futex_q {
static const struct futex_q futex_q_init = {
/* list gets initialized in queue_me()*/
.key = FUTEX_KEY_INIT,
+ /* .mm_ref must be initialized for each futex_q */
.bitset = FUTEX_BITSET_MATCH_ANY
};

@@ -336,9 +338,9 @@ static inline bool should_fail_futex(bool fshared)
}
#endif /* CONFIG_FAIL_FUTEX */

-static inline void futex_get_mm(union futex_key *key)
+static inline void futex_get_mm(union futex_key *key, struct mm_ref *ref)
{
- mmgrab(key->private.mm);
+ mmgrab(key->private.mm, ref);
/*
* Ensure futex_get_mm() implies a full barrier such that
* get_futex_key() implies a full barrier. This is relied upon
@@ -417,7 +419,7 @@ static inline int match_futex(union futex_key *key1, union futex_key *key2)
* Can be called while holding spinlocks.
*
*/
-static void get_futex_key_refs(union futex_key *key)
+static void get_futex_key_refs(union futex_key *key, struct mm_ref *ref)
{
if (!key->both.ptr)
return;
@@ -437,7 +439,7 @@ static void get_futex_key_refs(union futex_key *key)
ihold(key->shared.inode); /* implies smp_mb(); (B) */
break;
case FUT_OFF_MMSHARED:
- futex_get_mm(key); /* implies smp_mb(); (B) */
+ futex_get_mm(key, ref); /* implies smp_mb(); (B) */
break;
default:
/*
@@ -455,7 +457,7 @@ static void get_futex_key_refs(union futex_key *key)
* a no-op for private futexes, see comment in the get
* counterpart.
*/
-static void drop_futex_key_refs(union futex_key *key)
+static void drop_futex_key_refs(union futex_key *key, struct mm_ref *ref)
{
if (!key->both.ptr) {
/* If we're here then we tried to put a key we failed to get */
@@ -471,7 +473,7 @@ static void drop_futex_key_refs(union futex_key *key)
iput(key->shared.inode);
break;
case FUT_OFF_MMSHARED:
- mmdrop(key->private.mm);
+ mmdrop(key->private.mm, ref);
break;
}
}
@@ -495,7 +497,7 @@ static void drop_futex_key_refs(union futex_key *key)
* lock_page() might sleep, the caller should not hold a spinlock.
*/
static int
-get_futex_key(u32 __user *uaddr, int fshared, union futex_key *key, int rw)
+get_futex_key(u32 __user *uaddr, int fshared, union futex_key *key, int rw, struct mm_ref *mm_ref)
{
unsigned long address = (unsigned long)uaddr;
struct mm_struct *mm = current->mm;
@@ -527,7 +529,7 @@ get_futex_key(u32 __user *uaddr, int fshared, union futex_key *key, int rw)
if (!fshared) {
key->private.mm = mm;
key->private.address = address;
- get_futex_key_refs(key); /* implies smp_mb(); (B) */
+ get_futex_key_refs(key, mm_ref); /* implies smp_mb(); (B) */
return 0;
}

@@ -630,7 +632,7 @@ get_futex_key(u32 __user *uaddr, int fshared, union futex_key *key, int rw)
key->private.mm = mm;
key->private.address = address;

- get_futex_key_refs(key); /* implies smp_mb(); (B) */
+ get_futex_key_refs(key, mm_ref); /* implies smp_mb(); (B) */

} else {
struct inode *inode;
@@ -701,9 +703,9 @@ get_futex_key(u32 __user *uaddr, int fshared, union futex_key *key, int rw)
return err;
}

-static inline void put_futex_key(union futex_key *key)
+static inline void put_futex_key(union futex_key *key, struct mm_ref *mm_ref)
{
- drop_futex_key_refs(key);
+ drop_futex_key_refs(key, mm_ref);
}

/**
@@ -1414,13 +1416,14 @@ futex_wake(u32 __user *uaddr, unsigned int flags, int nr_wake, u32 bitset)
struct futex_hash_bucket *hb;
struct futex_q *this, *next;
union futex_key key = FUTEX_KEY_INIT;
+ MM_REF(mm_ref);
int ret;
WAKE_Q(wake_q);

if (!bitset)
return -EINVAL;

- ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &key, VERIFY_READ);
+ ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &key, VERIFY_READ, &mm_ref);
if (unlikely(ret != 0))
goto out;

@@ -1452,7 +1455,7 @@ futex_wake(u32 __user *uaddr, unsigned int flags, int nr_wake, u32 bitset)
spin_unlock(&hb->lock);
wake_up_q(&wake_q);
out_put_key:
- put_futex_key(&key);
+ put_futex_key(&key, &mm_ref);
out:
return ret;
}
@@ -1466,16 +1469,18 @@ futex_wake_op(u32 __user *uaddr1, unsigned int flags, u32 __user *uaddr2,
int nr_wake, int nr_wake2, int op)
{
union futex_key key1 = FUTEX_KEY_INIT, key2 = FUTEX_KEY_INIT;
+ MM_REF(mm_ref1);
+ MM_REF(mm_ref2);
struct futex_hash_bucket *hb1, *hb2;
struct futex_q *this, *next;
int ret, op_ret;
WAKE_Q(wake_q);

retry:
- ret = get_futex_key(uaddr1, flags & FLAGS_SHARED, &key1, VERIFY_READ);
+ ret = get_futex_key(uaddr1, flags & FLAGS_SHARED, &key1, VERIFY_READ, &mm_ref1);
if (unlikely(ret != 0))
goto out;
- ret = get_futex_key(uaddr2, flags & FLAGS_SHARED, &key2, VERIFY_WRITE);
+ ret = get_futex_key(uaddr2, flags & FLAGS_SHARED, &key2, VERIFY_WRITE, &mm_ref2);
if (unlikely(ret != 0))
goto out_put_key1;

@@ -1510,8 +1515,8 @@ futex_wake_op(u32 __user *uaddr1, unsigned int flags, u32 __user *uaddr2,
if (!(flags & FLAGS_SHARED))
goto retry_private;

- put_futex_key(&key2);
- put_futex_key(&key1);
+ put_futex_key(&key2, &mm_ref2);
+ put_futex_key(&key1, &mm_ref1);
goto retry;
}

@@ -1547,9 +1552,9 @@ futex_wake_op(u32 __user *uaddr1, unsigned int flags, u32 __user *uaddr2,
double_unlock_hb(hb1, hb2);
wake_up_q(&wake_q);
out_put_keys:
- put_futex_key(&key2);
+ put_futex_key(&key2, &mm_ref2);
out_put_key1:
- put_futex_key(&key1);
+ put_futex_key(&key1, &mm_ref1);
out:
return ret;
}
@@ -1563,7 +1568,7 @@ futex_wake_op(u32 __user *uaddr1, unsigned int flags, u32 __user *uaddr2,
*/
static inline
void requeue_futex(struct futex_q *q, struct futex_hash_bucket *hb1,
- struct futex_hash_bucket *hb2, union futex_key *key2)
+ struct futex_hash_bucket *hb2, union futex_key *key2, struct mm_ref *mm_ref2)
{

/*
@@ -1577,7 +1582,7 @@ void requeue_futex(struct futex_q *q, struct futex_hash_bucket *hb1,
plist_add(&q->list, &hb2->chain);
q->lock_ptr = &hb2->lock;
}
- get_futex_key_refs(key2);
+ get_futex_key_refs(key2, mm_ref2);
q->key = *key2;
}

@@ -1597,9 +1602,9 @@ void requeue_futex(struct futex_q *q, struct futex_hash_bucket *hb1,
*/
static inline
void requeue_pi_wake_futex(struct futex_q *q, union futex_key *key,
- struct futex_hash_bucket *hb)
+ struct futex_hash_bucket *hb, struct mm_ref *ref)
{
- get_futex_key_refs(key);
+ get_futex_key_refs(key, ref);
q->key = *key;

__unqueue_futex(q);
@@ -1636,7 +1641,8 @@ static int futex_proxy_trylock_atomic(u32 __user *pifutex,
struct futex_hash_bucket *hb1,
struct futex_hash_bucket *hb2,
union futex_key *key1, union futex_key *key2,
- struct futex_pi_state **ps, int set_waiters)
+ struct futex_pi_state **ps, int set_waiters,
+ struct mm_ref *mm_ref2)
{
struct futex_q *top_waiter = NULL;
u32 curval;
@@ -1675,7 +1681,7 @@ static int futex_proxy_trylock_atomic(u32 __user *pifutex,
ret = futex_lock_pi_atomic(pifutex, hb2, key2, ps, top_waiter->task,
set_waiters);
if (ret == 1) {
- requeue_pi_wake_futex(top_waiter, key2, hb2);
+ requeue_pi_wake_futex(top_waiter, key2, hb2, mm_ref2);
return vpid;
}
return ret;
@@ -1704,6 +1710,8 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
u32 *cmpval, int requeue_pi)
{
union futex_key key1 = FUTEX_KEY_INIT, key2 = FUTEX_KEY_INIT;
+ MM_REF(mm_ref1);
+ MM_REF(mm_ref2);
int drop_count = 0, task_count = 0, ret;
struct futex_pi_state *pi_state = NULL;
struct futex_hash_bucket *hb1, *hb2;
@@ -1739,11 +1747,11 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
}

retry:
- ret = get_futex_key(uaddr1, flags & FLAGS_SHARED, &key1, VERIFY_READ);
+ ret = get_futex_key(uaddr1, flags & FLAGS_SHARED, &key1, VERIFY_READ, &mm_ref1);
if (unlikely(ret != 0))
goto out;
ret = get_futex_key(uaddr2, flags & FLAGS_SHARED, &key2,
- requeue_pi ? VERIFY_WRITE : VERIFY_READ);
+ requeue_pi ? VERIFY_WRITE : VERIFY_READ, &mm_ref2);
if (unlikely(ret != 0))
goto out_put_key1;

@@ -1779,8 +1787,8 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
if (!(flags & FLAGS_SHARED))
goto retry_private;

- put_futex_key(&key2);
- put_futex_key(&key1);
+ put_futex_key(&key2, &mm_ref2);
+ put_futex_key(&key1, &mm_ref1);
goto retry;
}
if (curval != *cmpval) {
@@ -1797,7 +1805,7 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
* faults rather in the requeue loop below.
*/
ret = futex_proxy_trylock_atomic(uaddr2, hb1, hb2, &key1,
- &key2, &pi_state, nr_requeue);
+ &key2, &pi_state, nr_requeue, &mm_ref2);

/*
* At this point the top_waiter has either taken uaddr2 or is
@@ -1836,8 +1844,8 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
case -EFAULT:
double_unlock_hb(hb1, hb2);
hb_waiters_dec(hb2);
- put_futex_key(&key2);
- put_futex_key(&key1);
+ put_futex_key(&key2, &mm_ref2);
+ put_futex_key(&key1, &mm_ref1);
ret = fault_in_user_writeable(uaddr2);
if (!ret)
goto retry;
@@ -1851,8 +1859,8 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
*/
double_unlock_hb(hb1, hb2);
hb_waiters_dec(hb2);
- put_futex_key(&key2);
- put_futex_key(&key1);
+ put_futex_key(&key2, &mm_ref2);
+ put_futex_key(&key1, &mm_ref1);
cond_resched();
goto retry;
default:
@@ -1921,7 +1929,7 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
* value. It will drop the refcount after
* doing so.
*/
- requeue_pi_wake_futex(this, &key2, hb2);
+ requeue_pi_wake_futex(this, &key2, hb2, &mm_ref2);
drop_count++;
continue;
} else if (ret) {
@@ -1942,7 +1950,7 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
break;
}
}
- requeue_futex(this, hb1, hb2, &key2);
+ requeue_futex(this, hb1, hb2, &key2, &mm_ref2);
drop_count++;
}

@@ -1965,12 +1973,12 @@ static int futex_requeue(u32 __user *uaddr1, unsigned int flags,
* hold the references to key1.
*/
while (--drop_count >= 0)
- drop_futex_key_refs(&key1);
+ drop_futex_key_refs(&key1, &mm_ref1);

out_put_keys:
- put_futex_key(&key2);
+ put_futex_key(&key2, &mm_ref2);
out_put_key1:
- put_futex_key(&key1);
+ put_futex_key(&key1, &mm_ref1);
out:
return ret ? ret : task_count;
}
@@ -2091,7 +2099,7 @@ static int unqueue_me(struct futex_q *q)
ret = 1;
}

- drop_futex_key_refs(&q->key);
+ drop_futex_key_refs(&q->key, &q->mm_ref);
return ret;
}

@@ -2365,7 +2373,7 @@ static int futex_wait_setup(u32 __user *uaddr, u32 val, unsigned int flags,
* while the syscall executes.
*/
retry:
- ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &q->key, VERIFY_READ);
+ ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &q->key, VERIFY_READ, &q->mm_ref);
if (unlikely(ret != 0))
return ret;

@@ -2384,7 +2392,7 @@ static int futex_wait_setup(u32 __user *uaddr, u32 val, unsigned int flags,
if (!(flags & FLAGS_SHARED))
goto retry_private;

- put_futex_key(&q->key);
+ put_futex_key(&q->key, &q->mm_ref);
goto retry;
}

@@ -2395,7 +2403,7 @@ static int futex_wait_setup(u32 __user *uaddr, u32 val, unsigned int flags,

out:
if (ret)
- put_futex_key(&q->key);
+ put_futex_key(&q->key, &q->mm_ref);
return ret;
}

@@ -2408,6 +2416,8 @@ static int futex_wait(u32 __user *uaddr, unsigned int flags, u32 val,
struct futex_q q = futex_q_init;
int ret;

+ INIT_MM_REF(&q.mm_ref);
+
if (!bitset)
return -EINVAL;
q.bitset = bitset;
@@ -2507,6 +2517,8 @@ static int futex_lock_pi(u32 __user *uaddr, unsigned int flags,
struct futex_q q = futex_q_init;
int res, ret;

+ INIT_MM_REF(&q.mm_ref);
+
if (refill_pi_state_cache())
return -ENOMEM;

@@ -2519,7 +2531,7 @@ static int futex_lock_pi(u32 __user *uaddr, unsigned int flags,
}

retry:
- ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &q.key, VERIFY_WRITE);
+ ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &q.key, VERIFY_WRITE, &q.mm_ref);
if (unlikely(ret != 0))
goto out;

@@ -2547,7 +2559,7 @@ static int futex_lock_pi(u32 __user *uaddr, unsigned int flags,
* - The user space value changed.
*/
queue_unlock(hb);
- put_futex_key(&q.key);
+ put_futex_key(&q.key, &q.mm_ref);
cond_resched();
goto retry;
default:
@@ -2601,7 +2613,7 @@ static int futex_lock_pi(u32 __user *uaddr, unsigned int flags,
queue_unlock(hb);

out_put_key:
- put_futex_key(&q.key);
+ put_futex_key(&q.key, &q.mm_ref);
out:
if (to)
destroy_hrtimer_on_stack(&to->timer);
@@ -2617,7 +2629,7 @@ static int futex_lock_pi(u32 __user *uaddr, unsigned int flags,
if (!(flags & FLAGS_SHARED))
goto retry_private;

- put_futex_key(&q.key);
+ put_futex_key(&q.key, &q.mm_ref);
goto retry;
}

@@ -2630,6 +2642,7 @@ static int futex_unlock_pi(u32 __user *uaddr, unsigned int flags)
{
u32 uninitialized_var(curval), uval, vpid = task_pid_vnr(current);
union futex_key key = FUTEX_KEY_INIT;
+ MM_REF(mm_ref);
struct futex_hash_bucket *hb;
struct futex_q *match;
int ret;
@@ -2643,7 +2656,7 @@ static int futex_unlock_pi(u32 __user *uaddr, unsigned int flags)
if ((uval & FUTEX_TID_MASK) != vpid)
return -EPERM;

- ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &key, VERIFY_WRITE);
+ ret = get_futex_key(uaddr, flags & FLAGS_SHARED, &key, VERIFY_WRITE, &mm_ref);
if (ret)
return ret;

@@ -2676,7 +2689,7 @@ static int futex_unlock_pi(u32 __user *uaddr, unsigned int flags)
*/
if (ret == -EAGAIN) {
spin_unlock(&hb->lock);
- put_futex_key(&key);
+ put_futex_key(&key, &mm_ref);
goto retry;
}
/*
@@ -2704,12 +2717,12 @@ static int futex_unlock_pi(u32 __user *uaddr, unsigned int flags)
out_unlock:
spin_unlock(&hb->lock);
out_putkey:
- put_futex_key(&key);
+ put_futex_key(&key, &mm_ref);
return ret;

pi_faulted:
spin_unlock(&hb->lock);
- put_futex_key(&key);
+ put_futex_key(&key, &mm_ref);

ret = fault_in_user_writeable(uaddr);
if (!ret)
@@ -2816,9 +2829,12 @@ static int futex_wait_requeue_pi(u32 __user *uaddr, unsigned int flags,
struct rt_mutex *pi_mutex = NULL;
struct futex_hash_bucket *hb;
union futex_key key2 = FUTEX_KEY_INIT;
+ MM_REF(mm_ref2);
struct futex_q q = futex_q_init;
int res, ret;

+ INIT_MM_REF(&q.mm_ref);
+
if (uaddr == uaddr2)
return -EINVAL;

@@ -2844,7 +2860,7 @@ static int futex_wait_requeue_pi(u32 __user *uaddr, unsigned int flags,
RB_CLEAR_NODE(&rt_waiter.tree_entry);
rt_waiter.task = NULL;

- ret = get_futex_key(uaddr2, flags & FLAGS_SHARED, &key2, VERIFY_WRITE);
+ ret = get_futex_key(uaddr2, flags & FLAGS_SHARED, &key2, VERIFY_WRITE, &mm_ref2);
if (unlikely(ret != 0))
goto out;

@@ -2951,9 +2967,9 @@ static int futex_wait_requeue_pi(u32 __user *uaddr, unsigned int flags,
}

out_put_keys:
- put_futex_key(&q.key);
+ put_futex_key(&q.key, &q.mm_ref);
out_key2:
- put_futex_key(&key2);
+ put_futex_key(&key2, &mm_ref2);

out:
if (to) {
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index ee1fb0070544..460c57d0d9af 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -2771,7 +2771,7 @@ static struct rq *finish_task_switch(struct task_struct *prev)

fire_sched_in_preempt_notifiers(current);
if (mm)
- mmdrop(mm);
+ mmdrop(mm, &rq->prev_mm_ref);
if (unlikely(prev_state == TASK_DEAD)) {
if (prev->sched_class->task_dead)
prev->sched_class->task_dead(prev);
@@ -2877,12 +2877,14 @@ context_switch(struct rq *rq, struct task_struct *prev,

if (!mm) {
next->active_mm = oldmm;
- mmgrab(oldmm);
+ mmgrab(oldmm, &next->mm_ref);
enter_lazy_tlb(oldmm, next);
} else
switch_mm_irqs_off(oldmm, mm, next);

if (!prev->mm) {
+ if (oldmm)
+ move_mm_ref(oldmm, &prev->mm_ref, &rq->prev_mm_ref);
prev->active_mm = NULL;
rq->prev_mm = oldmm;
}
@@ -5472,7 +5474,7 @@ void idle_task_exit(void)
switch_mm_irqs_off(mm, &init_mm, current);
finish_arch_post_lock_switch();
}
- mmdrop(mm);
+ mmdrop(mm, &current->mm_ref);
}

/*
@@ -7640,6 +7642,10 @@ void __init sched_init(void)
rq->balance_callback = NULL;
rq->active_balance = 0;
rq->next_balance = jiffies;
+
+ BUG_ON(rq->prev_mm != NULL);
+ INIT_MM_REF(&rq->prev_mm_ref);
+
rq->push_cpu = 0;
rq->cpu = i;
rq->online = 0;
@@ -7667,7 +7673,7 @@ void __init sched_init(void)
/*
* The boot idle thread does lazy MMU switching as well:
*/
- mmgrab(&init_mm);
+ mmgrab(&init_mm, &init_mm_ref);
enter_lazy_tlb(&init_mm, current);

/*
diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index 055f935d4421..98680abb882a 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -636,6 +636,7 @@ struct rq {
struct task_struct *curr, *idle, *stop;
unsigned long next_balance;
struct mm_struct *prev_mm;
+ struct mm_ref prev_mm_ref;

unsigned int clock_skip_update;
u64 clock;
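The scheduler hunks hand a lazy-TLB reference from one owner to another. When the next task has no mm, it borrows prev's active_mm and pins it through its own next->mm_ref. When prev itself had no mm, the reference it was holding is not dropped in context_switch(). Instead it is parked in rq->prev_mm and dropped in finish_task_switch(). Since the runqueue now owns that reference, the patch moves it from prev->mm_ref to rq->prev_mm_ref, which leaves prev->mm_ref inactive for the next time prev borrows an mm. The sketch below models only that reference flow in userspace and makes several assumptions. struct mm, struct ref, ref_grab(), ref_drop(), ref_move() and switch_to() are all hypothetical. A ref just records which mm it pins. switch_mm(), enter_lazy_tlb(), locking and ->mm_users are left out.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model: an mm with a lazy-reference count, and refs that
 * remember which mm they pin (NULL means the ref is inactive). */
struct mm { int count; };
struct ref { struct mm *pinned; };

static bool ref_grab(struct mm *mm, struct ref *r)
{
	if (r->pinned)
		return false;            /* would overwrite an active ref */
	mm->count++;
	r->pinned = mm;
	return true;
}

static bool ref_drop(struct mm *mm, struct ref *r)
{
	if (r->pinned != mm)
		return false;            /* this ref does not pin this mm */
	r->pinned = NULL;
	mm->count--;
	return true;
}

static bool ref_move(struct mm *mm, struct ref *from, struct ref *to)
{
	if (from->pinned != mm || to->pinned)
		return false;
	to->pinned = mm;
	from->pinned = NULL;
	return true;
}

struct task { struct mm *mm, *active_mm; struct ref mm_ref; };
struct runqueue { struct mm *prev_mm; struct ref prev_mm_ref; };

/* Reference flow of context_switch() followed by finish_task_switch(). */
bool switch_to(struct runqueue *rq, struct task *prev, struct task *next)
{
	struct mm *oldmm = prev->active_mm;

	if (!next->mm) {
		/* Kernel thread: borrow prev's mm through next's own ref. */
		next->active_mm = oldmm;
		if (!ref_grab(oldmm, &next->mm_ref))
			return false;
	}
	if (!prev->mm) {
		/* prev's borrowed ref now belongs to the runqueue. */
		if (oldmm && !ref_move(oldmm, &prev->mm_ref, &rq->prev_mm_ref))
			return false;
		prev->active_mm = NULL;
		rq->prev_mm = oldmm;
	}

	/* finish_task_switch(): drop the parked reference, if any. */
	if (rq->prev_mm) {
		struct mm *mm = rq->prev_mm;

		rq->prev_mm = NULL;
		if (!ref_drop(mm, &rq->prev_mm_ref))
			return false;
	}
	return true;
}
```

If the ref_move() call is removed, ref_drop() on rq->prev_mm_ref fails. The runqueue would then be dropping a reference it never took, which is the kind of mismatch the patch reports at the drop site.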
diff --git a/kernel/sys.c b/kernel/sys.c
index 89d5be418157..01a5bd227a53 100644
--- a/kernel/sys.c
+++ b/kernel/sys.c
@@ -1603,11 +1603,12 @@ static void k_getrusage(struct task_struct *p, int who, struct rusage *r)
cputime_to_timeval(stime, &r->ru_stime);

if (who != RUSAGE_CHILDREN) {
- struct mm_struct *mm = get_task_mm(p);
+ MM_REF(mm_ref);
+ struct mm_struct *mm = get_task_mm(p, &mm_ref);

if (mm) {
setmax_mm_hiwater_rss(&maxrss, mm);
- mmput(mm);
+ mmput(mm, &mm_ref);
}
}
r->ru_maxrss = maxrss * (PAGE_SIZE / 1024); /* convert pages to KBs */
diff --git a/kernel/trace/trace_output.c b/kernel/trace/trace_output.c
index 3fc20422c166..decd72ec58e1 100644
--- a/kernel/trace/trace_output.c
+++ b/kernel/trace/trace_output.c
@@ -1046,6 +1046,7 @@ static enum print_line_t trace_user_stack_print(struct trace_iterator *iter,
struct userstack_entry *field;
struct trace_seq *s = &iter->seq;
struct mm_struct *mm = NULL;
+ MM_REF(mm_ref);
unsigned int i;

trace_assign_type(field, iter->ent);
@@ -1061,7 +1062,7 @@ static enum print_line_t trace_user_stack_print(struct trace_iterator *iter,
rcu_read_lock();
task = find_task_by_vpid(field->tgid);
if (task)
- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
rcu_read_unlock();
}

@@ -1084,7 +1085,7 @@ static enum print_line_t trace_user_stack_print(struct trace_iterator *iter,
}

if (mm)
- mmput(mm);
+ mmput(mm, &mm_ref);

return trace_handle_return(s);
}
diff --git a/kernel/tsacct.c b/kernel/tsacct.c
index f8e26ab963ed..58595a3dca3f 100644
--- a/kernel/tsacct.c
+++ b/kernel/tsacct.c
@@ -92,18 +92,19 @@ void bacct_add_tsk(struct user_namespace *user_ns,
void xacct_add_tsk(struct taskstats *stats, struct task_struct *p)
{
struct mm_struct *mm;
+ MM_REF(mm_ref);

/* convert pages-nsec/1024 to Mbyte-usec, see __acct_update_integrals */
stats->coremem = p->acct_rss_mem1 * PAGE_SIZE;
do_div(stats->coremem, 1000 * KB);
stats->virtmem = p->acct_vm_mem1 * PAGE_SIZE;
do_div(stats->virtmem, 1000 * KB);
- mm = get_task_mm(p);
+ mm = get_task_mm(p, &mm_ref);
if (mm) {
/* adjust to KB unit */
stats->hiwater_rss = get_mm_hiwater_rss(mm) * PAGE_SIZE / KB;
stats->hiwater_vm = get_mm_hiwater_vm(mm) * PAGE_SIZE / KB;
- mmput(mm);
+ mmput(mm, &mm_ref);
}
stats->read_char = p->ioac.rchar & KB_MASK;
stats->write_char = p->ioac.wchar & KB_MASK;
diff --git a/mm/Makefile b/mm/Makefile
index 295bd7a9f76b..1d6acdf0a4a7 100644
--- a/mm/Makefile
+++ b/mm/Makefile
@@ -37,7 +37,7 @@ obj-y := filemap.o mempool.o oom_kill.o \
mm_init.o mmu_context.o percpu.o slab_common.o \
compaction.o vmacache.o \
interval_tree.o list_lru.o workingset.o \
- debug.o $(mmu-y)
+ debug.o mm_ref.o $(mmu-y)

obj-y += init-mm.o

diff --git a/mm/init-mm.c b/mm/init-mm.c
index a56a851908d2..deb315a4c240 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -16,10 +16,16 @@
struct mm_struct init_mm = {
.mm_rb = RB_ROOT,
.pgd = swapper_pg_dir,
+ .mm_refs_lock = __SPIN_LOCK_UNLOCKED(init_mm.mm_refs_lock),
.mm_users = ATOMIC_INIT(2),
+ .mm_users_list = LIST_HEAD_INIT(init_mm.mm_users_list),
+ .mm_users_ref = MM_REF_INIT(init_mm.mm_users_ref),
.mm_count = ATOMIC_INIT(1),
+ .mm_count_list = LIST_HEAD_INIT(init_mm.mm_count_list),
.mmap_sem = __RWSEM_INITIALIZER(init_mm.mmap_sem),
.page_table_lock = __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
.mmlist = LIST_HEAD_INIT(init_mm.mmlist),
INIT_MM_CONTEXT(init_mm)
};
+
+MM_REF(init_mm_ref);
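MM_REF(init_mm_ref) appears at file scope in mm/init-mm.c. MM_REF() must therefore expand to a definition with a constant initializer rather than a call to INIT_MM_REF(), which mm/mm_ref.c defines as an ordinary function. The header that defines MM_REF() and MM_REF_INIT() is not part of this section. The following is only a guess at the shape such macros can take. It uses a self-referencing LIST_HEAD_INIT() in the same way the init_mm initializer above does for mm_users_list and mm_count_list. REF(), REF_INIT() and the state value are hypothetical.

```c
#include <assert.h>
#include <stdbool.h>

/* Minimal stand-ins for <linux/list.h> (same shape as the kernel's). */
struct list_head { struct list_head *next, *prev; };
#define LIST_HEAD_INIT(name) { &(name), &(name) }

/* Hypothetical state value; the real MM_REF_* values are not shown here. */
enum { REF_INITIALIZED = 1 };

struct ref {
	struct list_head list_entry;
	unsigned int state;
};

/*
 * Hypothetical counterparts of MM_REF_INIT()/MM_REF(): the initializer
 * only takes addresses within the object being defined, so it is a
 * constant expression and works at file scope as well as on the stack.
 */
#define REF_INIT(name) {					\
	.list_entry = LIST_HEAD_INIT((name).list_entry),	\
	.state = REF_INITIALIZED,				\
}
#define REF(name) struct ref name = REF_INIT(name)

REF(boot_ref);	/* file scope, like MM_REF(init_mm_ref) */

/* A fresh ref is unlinked (points at itself) and in the initial state. */
bool ref_is_fresh(const struct ref *r)
{
	return r->list_entry.next == &r->list_entry &&
	       r->list_entry.prev == &r->list_entry &&
	       r->state == REF_INITIALIZED;
}

bool local_ref_is_fresh(void)
{
	REF(tmp);	/* block scope, like MM_REF(mm_ref) in the callers */
	return ref_is_fresh(&tmp);
}
```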
diff --git a/mm/memory.c b/mm/memory.c
index e18c57bdc75c..3be253b54c04 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3954,15 +3954,16 @@ int access_process_vm(struct task_struct *tsk, unsigned long addr,
void *buf, int len, unsigned int gup_flags)
{
struct mm_struct *mm;
+ MM_REF(mm_ref);
int ret;

- mm = get_task_mm(tsk);
+ mm = get_task_mm(tsk, &mm_ref);
if (!mm)
return 0;

ret = __access_remote_vm(tsk, mm, addr, buf, len, gup_flags);

- mmput(mm);
+ mmput(mm, &mm_ref);

return ret;
}
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 0b859af06b87..4790274af596 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1374,6 +1374,7 @@ SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
{
const struct cred *cred = current_cred(), *tcred;
struct mm_struct *mm = NULL;
+ MM_REF(mm_ref);
struct task_struct *task;
nodemask_t task_nodes;
int err;
@@ -1439,7 +1440,7 @@ SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
if (err)
goto out_put;

- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
put_task_struct(task);

if (!mm) {
@@ -1450,7 +1451,7 @@ SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
err = do_migrate_pages(mm, old, new,
capable(CAP_SYS_NICE) ? MPOL_MF_MOVE_ALL : MPOL_MF_MOVE);

- mmput(mm);
+ mmput(mm, &mm_ref);
out:
NODEMASK_SCRATCH_FREE(scratch);

diff --git a/mm/migrate.c b/mm/migrate.c
index 99250aee1ac1..593942dabbc1 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -1659,6 +1659,7 @@ SYSCALL_DEFINE6(move_pages, pid_t, pid, unsigned long, nr_pages,
const struct cred *cred = current_cred(), *tcred;
struct task_struct *task;
struct mm_struct *mm;
+ MM_REF(mm_ref);
int err;
nodemask_t task_nodes;

@@ -1699,7 +1700,7 @@ SYSCALL_DEFINE6(move_pages, pid_t, pid, unsigned long, nr_pages,
goto out;

task_nodes = cpuset_mems_allowed(task);
- mm = get_task_mm(task);
+ mm = get_task_mm(task, &mm_ref);
put_task_struct(task);

if (!mm)
@@ -1711,7 +1712,7 @@ SYSCALL_DEFINE6(move_pages, pid_t, pid, unsigned long, nr_pages,
else
err = do_pages_stat(mm, nr_pages, pages, status);

- mmput(mm);
+ mmput(mm, &mm_ref);
return err;

out:
diff --git a/mm/mm_ref.c b/mm/mm_ref.c
new file mode 100644
index 000000000000..cf14334aec58
--- /dev/null
+++ b/mm/mm_ref.c
@@ -0,0 +1,163 @@
+#include <linux/list.h>
+#include <linux/mm_ref.h>
+#include <linux/mm_types.h>
+#include <linux/sched.h>
+#include <linux/stacktrace.h>
+
+static void _mm_ref_save_trace(struct mm_ref *ref)
+{
+ ref->pid = current->pid;
+
+ /* Save stack trace */
+ ref->trace.nr_entries = 0;
+ ref->trace.entries = ref->trace_entries;
+ ref->trace.max_entries = NR_MM_REF_STACK_ENTRIES;
+ ref->trace.skip = 1;
+ save_stack_trace(&ref->trace);
+}
+
+void INIT_MM_REF(struct mm_ref *ref)
+{
+ _mm_ref_save_trace(ref);
+ INIT_LIST_HEAD(&ref->list_entry);
+ ref->state = MM_REF_INITIALIZED;
+}
+
+static void dump_refs_list(const char *label, struct list_head *list)
+{
+ struct mm_ref *ref;
+
+ if (list_empty(list)) {
+ printk(KERN_ERR "%s: no refs\n", label);
+ return;
+ }
+
+ printk(KERN_ERR "%s:\n", label);
+ list_for_each_entry(ref, list, list_entry) {
+ printk(KERN_ERR " - %p %x acquired by %d at:%s\n",
+ ref, ref->state,
+ ref->pid,
+ ref->state == MM_REF_ACTIVE ? "" : " (bogus)");
+ if (ref->state == MM_REF_ACTIVE)
+ print_stack_trace(&ref->trace, 2);
+ }
+}
+
+static void dump_refs(struct mm_struct *mm)
+{
+ unsigned long flags;
+
+ spin_lock_irqsave(&mm->mm_refs_lock, flags);
+ printk(KERN_ERR "mm_users = %u\n", atomic_read(&mm->mm_users));
+ dump_refs_list("mm_users_list", &mm->mm_users_list);
+ printk(KERN_ERR "mm_count = %u\n", atomic_read(&mm->mm_count));
+ dump_refs_list("mm_count_list", &mm->mm_count_list);
+ spin_unlock_irqrestore(&mm->mm_refs_lock, flags);
+}
+
+static bool _mm_ref_expect_inactive(struct mm_struct *mm, struct mm_ref *ref)
+{
+ if (ref->state == MM_REF_INITIALIZED || ref->state == MM_REF_INACTIVE)
+ return true;
+
+ if (ref->state == MM_REF_ACTIVE) {
+ printk(KERN_ERR "trying to overwrite active ref %p to mm %p\n", ref, mm);
+ printk(KERN_ERR "previous ref taken by %d at:\n", ref->pid);
+ print_stack_trace(&ref->trace, 0);
+ } else {
+ printk(KERN_ERR "trying to overwrite ref %p in unknown state %x to mm %p\n",
+ ref, ref->state, mm);
+ }
+
+ return false;
+}
+
+static bool _mm_ref_expect_active(struct mm_struct *mm, struct mm_ref *ref)
+{
+ if (ref->state == MM_REF_ACTIVE)
+ return true;
+
+ if (ref->state == MM_REF_INITIALIZED || ref->state == MM_REF_INACTIVE) {
+ printk(KERN_ERR "trying to put inactive ref %p to mm %p\n", ref, mm);
+ if (ref->state == MM_REF_INITIALIZED)
+ printk(KERN_ERR "ref initialized by %d at:\n", ref->pid);
+ else
+ printk(KERN_ERR "previous ref dropped by %d at:\n", ref->pid);
+ print_stack_trace(&ref->trace, 0);
+ } else {
+ printk(KERN_ERR "trying to put ref %p in unknown state %x to mm %p\n",
+ ref, ref->state, mm);
+ }
+
+ return false;
+}
+
+void _get_mm_ref(struct mm_struct *mm, struct list_head *list, struct mm_ref *ref)
+{
+ unsigned long flags;
+
+ if (!_mm_ref_expect_inactive(mm, ref)) {
+ dump_refs(mm);
+ BUG();
+ }
+
+ _mm_ref_save_trace(ref);
+
+ spin_lock_irqsave(&mm->mm_refs_lock, flags);
+ list_add_tail(&ref->list_entry, list);
+ spin_unlock_irqrestore(&mm->mm_refs_lock, flags);
+
+ ref->state = MM_REF_ACTIVE;
+}
+
+void _put_mm_ref(struct mm_struct *mm, struct list_head *list, struct mm_ref *ref)
+{
+ unsigned long flags;
+
+ if (!_mm_ref_expect_active(mm, ref)) {
+ dump_refs(mm);
+ BUG();
+ }
+
+ _mm_ref_save_trace(ref);
+
+ spin_lock_irqsave(&mm->mm_refs_lock, flags);
+ BUG_ON(list_empty(&ref->list_entry));
+ list_del_init(&ref->list_entry);
+ spin_unlock_irqrestore(&mm->mm_refs_lock, flags);
+
+ ref->state = MM_REF_INACTIVE;
+}
+
+/*
+ * TODO: we have a choice here whether to ignore mm == NULL or
+ * treat it as an error.
+ * TODO: there's also a question about whether old_ref == new_ref
+ * is an error or not.
+ */
+void _move_mm_ref(struct mm_struct *mm, struct list_head *list,
+ struct mm_ref *old_ref, struct mm_ref *new_ref)
+{
+ unsigned long flags;
+
+ if (!_mm_ref_expect_active(mm, old_ref)) {
+ dump_refs(mm);
+ BUG();
+ }
+ if (!_mm_ref_expect_inactive(mm, new_ref)) {
+ dump_refs(mm);
+ BUG();
+ }
+
+ _mm_ref_save_trace(old_ref);
+ _mm_ref_save_trace(new_ref);
+
+ spin_lock_irqsave(&mm->mm_refs_lock, flags);
+ BUG_ON(list_empty(&old_ref->list_entry));
+ list_del_init(&old_ref->list_entry);
+ list_add_tail(&new_ref->list_entry, list);
+ spin_unlock_irqrestore(&mm->mm_refs_lock, flags);
+
+ old_ref->state = MM_REF_INACTIVE;
+ new_ref->state = MM_REF_ACTIVE;
+}
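For readers who want to see the bookkeeping in mm/mm_ref.c without the kernel around it, here is a single-threaded userspace model of the same state machine. It is a sketch under stated assumptions, not the patch's code. The names (model_mm, model_ref, model_get(), model_put(), model_move()) and the state values are hypothetical, and there is no locking. A failed check returns false instead of calling dump_refs() and BUG(). Each call also updates the counter and the list together, whereas the commit message describes those as separate steps performed by mmgrab()/mmdrop() and friends. The model illustrates the property from the commit message: a put through a ref that is not active fails at the put itself, and the list length always matches the counter.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical stand-ins for the MM_REF_* states. */
enum ref_state { REF_INITIALIZED = 1, REF_ACTIVE, REF_INACTIVE };

/* Minimal circular doubly linked list, same idea as <linux/list.h>. */
struct node { struct node *prev, *next; };

struct model_ref {
	struct node entry;	/* links this ref into its mm's ref list */
	enum ref_state state;
};

struct model_mm {
	int count;		/* plays the role of ->mm_count */
	struct node refs;	/* plays the role of ->mm_count_list */
};

static void node_init(struct node *n) { n->prev = n->next = n; }

static void node_add_tail(struct node *n, struct node *head)
{
	n->prev = head->prev;
	n->next = head;
	head->prev->next = n;
	head->prev = n;
}

static void node_del_init(struct node *n)
{
	n->prev->next = n->next;
	n->next->prev = n->prev;
	node_init(n);
}

void model_mm_init(struct model_mm *mm)
{
	mm->count = 0;
	node_init(&mm->refs);
}

void model_ref_init(struct model_ref *ref)	/* cf. INIT_MM_REF() */
{
	node_init(&ref->entry);
	ref->state = REF_INITIALIZED;
}

/* Take a reference; refuses to overwrite a ref that is already active. */
bool model_get(struct model_mm *mm, struct model_ref *ref)
{
	if (ref->state != REF_INITIALIZED && ref->state != REF_INACTIVE)
		return false;
	mm->count++;
	node_add_tail(&ref->entry, &mm->refs);
	ref->state = REF_ACTIVE;
	return true;
}

/* Drop a reference; only an active ref may be put. */
bool model_put(struct model_mm *mm, struct model_ref *ref)
{
	if (ref->state != REF_ACTIVE)
		return false;
	node_del_init(&ref->entry);
	ref->state = REF_INACTIVE;
	mm->count--;
	return true;
}

/* Transfer ownership from one ref to another; the count does not change. */
bool model_move(struct model_mm *mm, struct model_ref *from, struct model_ref *to)
{
	if (from->state != REF_ACTIVE)
		return false;
	if (to->state != REF_INITIALIZED && to->state != REF_INACTIVE)
		return false;
	node_del_init(&from->entry);
	node_add_tail(&to->entry, &mm->refs);
	from->state = REF_INACTIVE;
	to->state = REF_ACTIVE;
	return true;
}

/* Length of the ref list; should always equal mm->count. */
int model_nr_refs(const struct model_mm *mm)
{
	int n = 0;

	for (const struct node *p = mm->refs.next; p != &mm->refs; p = p->next)
		n++;
	return n;
}
```

The invariant model_nr_refs(mm) == mm->count is what dump_refs() lets you check by eye. If the rules in the commit message are followed, every unit of ->mm_users or ->mm_count corresponds to exactly one listed ref with a saved stack trace.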
diff --git a/mm/mmu_context.c b/mm/mmu_context.c
index daf67bb02b4a..3e28db145982 100644
--- a/mm/mmu_context.c
+++ b/mm/mmu_context.c
@@ -20,12 +20,14 @@
void use_mm(struct mm_struct *mm)
{
struct mm_struct *active_mm;
+ MM_REF(active_mm_ref);
struct task_struct *tsk = current;

task_lock(tsk);
active_mm = tsk->active_mm;
if (active_mm != mm) {
- mmgrab(mm);
+ move_mm_ref(active_mm, &tsk->mm_ref, &active_mm_ref);
+ mmgrab(mm, &tsk->mm_ref);
tsk->active_mm = mm;
}
tsk->mm = mm;
@@ -35,8 +37,9 @@ void use_mm(struct mm_struct *mm)
finish_arch_post_lock_switch();
#endif

- if (active_mm != mm)
- mmdrop(active_mm);
+ if (active_mm != mm) {
+ mmdrop(active_mm, &active_mm_ref);
+ }
}
EXPORT_SYMBOL_GPL(use_mm);

diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index 32bc9f2ff7eb..8187d46c8d05 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -244,7 +244,7 @@ EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range);

static int do_mmu_notifier_register(struct mmu_notifier *mn,
struct mm_struct *mm,
- int take_mmap_sem)
+ int take_mmap_sem, struct mm_ref *mm_ref)
{
struct mmu_notifier_mm *mmu_notifier_mm;
int ret;
@@ -275,7 +275,7 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn,
mm->mmu_notifier_mm = mmu_notifier_mm;
mmu_notifier_mm = NULL;
}
- mmgrab(mm);
+ mmgrab(mm, mm_ref);

/*
* Serialize the update against mmu_notifier_unregister. A
@@ -312,9 +312,9 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn,
* after exit_mmap. ->release will always be called before exit_mmap
* frees the pages.
*/
-int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
+int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm, struct mm_ref *mm_ref)
{
- return do_mmu_notifier_register(mn, mm, 1);
+ return do_mmu_notifier_register(mn, mm, 1, mm_ref);
}
EXPORT_SYMBOL_GPL(mmu_notifier_register);

@@ -322,9 +322,9 @@ EXPORT_SYMBOL_GPL(mmu_notifier_register);
* Same as mmu_notifier_register but here the caller must hold the
* mmap_sem in write mode.
*/
-int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
+int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm, struct mm_ref *mm_ref)
{
- return do_mmu_notifier_register(mn, mm, 0);
+ return do_mmu_notifier_register(mn, mm, 0, mm_ref);
}
EXPORT_SYMBOL_GPL(__mmu_notifier_register);

@@ -346,7 +346,7 @@ void __mmu_notifier_mm_destroy(struct mm_struct *mm)
* and only after mmu_notifier_unregister returned we're guaranteed
* that ->release or any other method can't run anymore.
*/
-void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm)
+void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm, struct mm_ref *mm_ref)
{
BUG_ON(atomic_read(&mm->mm_count) <= 0);

@@ -383,7 +383,7 @@ void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm)

BUG_ON(atomic_read(&mm->mm_count) <= 0);

- mmdrop(mm);
+ mmdrop(mm, mm_ref);
}
EXPORT_SYMBOL_GPL(mmu_notifier_unregister);

@@ -391,7 +391,7 @@ EXPORT_SYMBOL_GPL(mmu_notifier_unregister);
* Same as mmu_notifier_unregister but no callback and no srcu synchronization.
*/
void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
- struct mm_struct *mm)
+ struct mm_struct *mm, struct mm_ref *mm_ref)
{
spin_lock(&mm->mmu_notifier_mm->lock);
/*
@@ -402,7 +402,7 @@ void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
spin_unlock(&mm->mmu_notifier_mm->lock);

BUG_ON(atomic_read(&mm->mm_count) <= 0);
- mmdrop(mm);
+ mmdrop(mm, mm_ref);
}
EXPORT_SYMBOL_GPL(mmu_notifier_unregister_no_release);

diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index ead093c6f2a6..0aa0b364ec0e 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -463,6 +463,7 @@ static DEFINE_SPINLOCK(oom_reaper_lock);

static bool __oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
{
+ MM_REF(mm_ref);
struct mmu_gather tlb;
struct vm_area_struct *vma;
struct zap_details details = {.check_swap_entries = true,
@@ -495,7 +496,7 @@ static bool __oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
* that the mmput_async is called only when we have reaped something
* and delayed __mmput doesn't matter that much
*/
- if (!mmget_not_zero(mm)) {
+ if (!mmget_not_zero(mm, &mm_ref)) {
up_read(&mm->mmap_sem);
goto unlock_oom;
}
@@ -547,7 +548,7 @@ static bool __oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
* different context because we shouldn't risk we get stuck there and
* put the oom_reaper out of the way.
*/
- mmput_async(mm);
+ mmput_async(mm, &mm_ref);
unlock_oom:
mutex_unlock(&oom_lock);
return ret;
@@ -660,7 +661,7 @@ static void mark_oom_victim(struct task_struct *tsk)

/* oom_mm is bound to the signal struct life time. */
if (!cmpxchg(&tsk->signal->oom_mm, NULL, mm))
- mmgrab(tsk->signal->oom_mm);
+ mmgrab(tsk->signal->oom_mm, &tsk->signal->oom_mm_ref);

/*
* Make sure that the task is woken up from uninterruptible sleep
@@ -812,6 +813,7 @@ static void oom_kill_process(struct oom_control *oc, const char *message)
struct task_struct *child;
struct task_struct *t;
struct mm_struct *mm;
+ MM_REF(mm_ref);
unsigned int victim_points = 0;
static DEFINE_RATELIMIT_STATE(oom_rs, DEFAULT_RATELIMIT_INTERVAL,
DEFAULT_RATELIMIT_BURST);
@@ -877,7 +879,7 @@ static void oom_kill_process(struct oom_control *oc, const char *message)

/* Get a reference to safely compare mm after task_unlock(victim) */
mm = victim->mm;
- mmgrab(mm);
+ mmgrab(mm, &mm_ref);
/*
* We should send SIGKILL before setting TIF_MEMDIE in order to prevent
* the OOM victim from depleting the memory reserves from the user
@@ -928,7 +930,7 @@ static void oom_kill_process(struct oom_control *oc, const char *message)
if (can_oom_reap)
wake_oom_reaper(victim);

- mmdrop(mm);
+ mmdrop(mm, &mm_ref);
put_task_struct(victim);
}
#undef K
diff --git a/mm/process_vm_access.c b/mm/process_vm_access.c
index be8dc8d1edb9..8eef73c5ed81 100644
--- a/mm/process_vm_access.c
+++ b/mm/process_vm_access.c
@@ -155,6 +155,7 @@ static ssize_t process_vm_rw_core(pid_t pid, struct iov_iter *iter,
struct page *pp_stack[PVM_MAX_PP_ARRAY_COUNT];
struct page **process_pages = pp_stack;
struct mm_struct *mm;
+ MM_REF(mm_ref);
unsigned long i;
ssize_t rc = 0;
unsigned long nr_pages = 0;
@@ -202,7 +203,7 @@ static ssize_t process_vm_rw_core(pid_t pid, struct iov_iter *iter,
goto free_proc_pages;
}

- mm = mm_access(task, PTRACE_MODE_ATTACH_REALCREDS);
+ mm = mm_access(task, PTRACE_MODE_ATTACH_REALCREDS, &mm_ref);
if (!mm || IS_ERR(mm)) {
rc = IS_ERR(mm) ? PTR_ERR(mm) : -ESRCH;
/*
@@ -228,7 +229,7 @@ static ssize_t process_vm_rw_core(pid_t pid, struct iov_iter *iter,
if (total_len)
rc = total_len;

- mmput(mm);
+ mmput(mm, &mm_ref);

put_task_struct:
put_task_struct(task);
diff --git a/mm/swapfile.c b/mm/swapfile.c
index 8c92829326cb..781122d8be77 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -1376,6 +1376,7 @@ int try_to_unuse(unsigned int type, bool frontswap,
{
struct swap_info_struct *si = swap_info[type];
struct mm_struct *start_mm;
+ MM_REF(start_mm_ref);
volatile unsigned char *swap_map; /* swap_map is accessed without
* locking. Mark it as volatile
* to prevent compiler doing
@@ -1402,7 +1403,7 @@ int try_to_unuse(unsigned int type, bool frontswap,
* that.
*/
start_mm = &init_mm;
- mmget(&init_mm);
+ mmget(&init_mm, &start_mm_ref);

/*
* Keep on scanning until all entries have gone. Usually,
@@ -1449,9 +1450,9 @@ int try_to_unuse(unsigned int type, bool frontswap,
* Don't hold on to start_mm if it looks like exiting.
*/
if (atomic_read(&start_mm->mm_users) == 1) {
- mmput(start_mm);
+ mmput(start_mm, &start_mm_ref);
start_mm = &init_mm;
- mmget(&init_mm);
+ mmget(&init_mm, &start_mm_ref);
}

/*
@@ -1485,19 +1486,22 @@ int try_to_unuse(unsigned int type, bool frontswap,
int set_start_mm = (*swap_map >= swcount);
struct list_head *p = &start_mm->mmlist;
struct mm_struct *new_start_mm = start_mm;
+ MM_REF(new_start_mm_ref);
struct mm_struct *prev_mm = start_mm;
+ MM_REF(prev_mm_ref);
struct mm_struct *mm;
+ MM_REF(mm_ref);

- mmget(new_start_mm);
- mmget(prev_mm);
+ mmget(new_start_mm, &new_start_mm_ref);
+ mmget(prev_mm, &prev_mm_ref);
spin_lock(&mmlist_lock);
while (swap_count(*swap_map) && !retval &&
(p = p->next) != &start_mm->mmlist) {
mm = list_entry(p, struct mm_struct, mmlist);
- if (!mmget_not_zero(mm))
+ if (!mmget_not_zero(mm, &mm_ref))
continue;
spin_unlock(&mmlist_lock);
- mmput(prev_mm);
+ mmput(prev_mm, &prev_mm_ref);
prev_mm = mm;

cond_resched();
@@ -1511,17 +1515,18 @@ int try_to_unuse(unsigned int type, bool frontswap,
retval = unuse_mm(mm, entry, page);

if (set_start_mm && *swap_map < swcount) {
- mmput(new_start_mm);
- mmget(mm);
+ mmput(new_start_mm, &new_start_mm_ref);
+ mmget(mm, &new_start_mm_ref);
new_start_mm = mm;
set_start_mm = 0;
}
spin_lock(&mmlist_lock);
}
spin_unlock(&mmlist_lock);
- mmput(prev_mm);
- mmput(start_mm);
+ mmput(prev_mm, &prev_mm_ref);
+ mmput(start_mm, &start_mm_ref);
start_mm = new_start_mm;
+ move_mm_users_ref(start_mm, &new_start_mm_ref, &start_mm_ref);
}
if (retval) {
unlock_page(page);
@@ -1590,7 +1595,7 @@ int try_to_unuse(unsigned int type, bool frontswap,
}
}

- mmput(start_mm);
+ mmput(start_mm, &start_mm_ref);
return retval;
}

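try_to_unuse() passes mm pointers from one local variable to another (prev_mm = mm, start_mm = new_start_mm). With per-holder refs, the reference has to travel with the pointer, which is what the move_mm_users_ref() call after start_mm = new_start_mm does. The hypothetical sketch below models a hand-over-hand walk in userspace. Each mm is pinned through cur_ref before the previous one is released. After prev = cur, the reference is moved into prev_ref so that cur_ref is free at the top of the next iteration. Setting hand_over_ref to false leaves out that move, and the very next get through cur_ref is refused. That is the kind of immediate report the patch is designed to produce. All names are illustrative, and the sketch has no mmlist_lock and no real mm_struct.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model: a pinned mm and per-holder references. */
struct mm { int users; };
struct ref { struct mm *pinned; };	/* NULL: ref is free for reuse */

static bool ref_get(struct mm *mm, struct ref *r)
{
	if (r->pinned)
		return false;		/* would overwrite an active ref */
	mm->users++;
	r->pinned = mm;
	return true;
}

static bool ref_put(struct mm *mm, struct ref *r)
{
	if (r->pinned != mm)
		return false;		/* not ours to put */
	r->pinned = NULL;
	mm->users--;
	return true;
}

static bool ref_move(struct mm *mm, struct ref *from, struct ref *to)
{
	if (from->pinned != mm || to->pinned)
		return false;
	to->pinned = mm;
	from->pinned = NULL;
	return true;
}

/*
 * Visit every mm in 'list' hand-over-hand, starting from 'start'.
 * Returns false as soon as any reference is misused.
 */
bool walk(struct mm **list, size_t n, struct mm *start, bool hand_over_ref)
{
	struct ref prev_ref = { NULL }, cur_ref = { NULL };
	struct mm *prev = start;

	if (!ref_get(prev, &prev_ref))
		return false;
	for (size_t i = 0; i < n; i++) {
		struct mm *cur = list[i];

		if (!ref_get(cur, &cur_ref))	/* pin the next mm first */
			return false;
		if (!ref_put(prev, &prev_ref))	/* then release the previous one */
			return false;
		prev = cur;
		/* The pointer moved into 'prev', so its reference must move too. */
		if (hand_over_ref && !ref_move(cur, &cur_ref, &prev_ref))
			return false;
	}
	return ref_put(prev, &prev_ref);
}
```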
diff --git a/mm/util.c b/mm/util.c
index 1a41553db866..9bace6820707 100644
--- a/mm/util.c
+++ b/mm/util.c
@@ -607,7 +607,8 @@ int get_cmdline(struct task_struct *task, char *buffer, int buflen)
{
int res = 0;
unsigned int len;
- struct mm_struct *mm = get_task_mm(task);
+ MM_REF(mm_ref);
+ struct mm_struct *mm = get_task_mm(task, &mm_ref);
unsigned long arg_start, arg_end, env_start, env_end;
if (!mm)
goto out;
@@ -647,7 +648,7 @@ int get_cmdline(struct task_struct *task, char *buffer, int buflen)
}
}
out_mm:
- mmput(mm);
+ mmput(mm, &mm_ref);
out:
return res;
}
diff --git a/virt/kvm/async_pf.c b/virt/kvm/async_pf.c
index 9ec9cef2b207..972084e84bd6 100644
--- a/virt/kvm/async_pf.c
+++ b/virt/kvm/async_pf.c
@@ -108,7 +108,7 @@ static void async_pf_execute(struct work_struct *work)
if (swait_active(&vcpu->wq))
swake_up(&vcpu->wq);

- mmput(mm);
+ mmput(mm, &apf->mm_ref);
kvm_put_kvm(vcpu->kvm);
}

@@ -135,7 +135,7 @@ void kvm_clear_async_pf_completion_queue(struct kvm_vcpu *vcpu)
flush_work(&work->work);
#else
if (cancel_work_sync(&work->work)) {
- mmput(work->mm);
+ mmput(work->mm, &work->mm_ref);
kvm_put_kvm(vcpu->kvm); /* == work->vcpu->kvm */
kmem_cache_free(async_pf_cache, work);
}
@@ -200,7 +200,8 @@ int kvm_setup_async_pf(struct kvm_vcpu *vcpu, gva_t gva, unsigned long hva,
work->addr = hva;
work->arch = *arch;
work->mm = current->mm;
- mmget(work->mm);
+ INIT_MM_REF(&work->mm_ref);
+ mmget(work->mm, &work->mm_ref);
kvm_get_kvm(work->vcpu->kvm);

/* this can't really happen otherwise gfn_to_pfn_async
@@ -218,7 +219,7 @@ int kvm_setup_async_pf(struct kvm_vcpu *vcpu, gva_t gva, unsigned long hva,
return 1;
retry_sync:
kvm_put_kvm(work->vcpu->kvm);
- mmput(work->mm);
+ mmput(work->mm, &work->mm_ref);
kmem_cache_free(async_pf_cache, work);
return 0;
}
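The KVM async page fault code shows the embedded-ref pattern. The mm_users reference outlives kvm_setup_async_pf(), so it lives in the work item itself and is initialized with INIT_MM_REF() before mmget(). The same &work->mm_ref (&apf->mm_ref in the worker) is then passed on every release path: async_pf_execute(), the cancel path in kvm_clear_async_pf_completion_queue(), and the retry_sync error path. The userspace sketch below uses hypothetical names and state values to show why the INIT step matters. A ref in memory that was never initialized is not guaranteed to hold one of the states the checks accept, and in the sketch it does not, so the first get is refused. The allocation is filled with 0x6b to stand in for such memory.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical state values standing in for MM_REF_*. */
enum { REF_INITIALIZED = 1, REF_ACTIVE = 2, REF_INACTIVE = 3 };

struct mm { int users; };
struct ref { unsigned int state; };

static bool ref_get(struct mm *mm, struct ref *r)
{
	/* Like _mm_ref_expect_inactive(): only fresh or dropped refs. */
	if (r->state != REF_INITIALIZED && r->state != REF_INACTIVE)
		return false;
	mm->users++;
	r->state = REF_ACTIVE;
	return true;
}

static bool ref_put(struct mm *mm, struct ref *r)
{
	/* Like _mm_ref_expect_active(): only an active ref may be put. */
	if (r->state != REF_ACTIVE)
		return false;
	r->state = REF_INACTIVE;
	mm->users--;
	return true;
}

/* Hypothetical work item that keeps its mm pinned while queued. */
struct work { struct mm *mm; struct ref mm_ref; };

/* Allocate a work item whose bytes are garbage (0x6b everywhere). */
struct work *work_alloc_dirty(void)
{
	struct work *w = malloc(sizeof(*w));

	if (w)
		memset(w, 0x6b, sizeof(*w));
	return w;
}

/* Roughly kvm_setup_async_pf(): INIT_MM_REF() (optional here), then mmget(). */
bool work_setup(struct work *w, struct mm *mm, bool init_ref)
{
	w->mm = mm;
	if (init_ref)
		w->mm_ref.state = REF_INITIALIZED;
	return ref_get(mm, &w->mm_ref);
}

/* Any one release path (run, cancel or retry_sync): mmput(), then free. */
bool work_release(struct work *w)
{
	bool ok = ref_put(w->mm, &w->mm_ref);

	free(w);
	return ok;
}
```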
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 43914b981691..d608457033d5 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -482,7 +482,8 @@ static const struct mmu_notifier_ops kvm_mmu_notifier_ops = {
static int kvm_init_mmu_notifier(struct kvm *kvm)
{
kvm->mmu_notifier.ops = &kvm_mmu_notifier_ops;
- return mmu_notifier_register(&kvm->mmu_notifier, current->mm);
+ return mmu_notifier_register(&kvm->mmu_notifier, current->mm,
+ &kvm->mmu_notifier_ref);
}

#else /* !(CONFIG_MMU_NOTIFIER && KVM_ARCH_WANT_MMU_NOTIFIER) */
@@ -608,12 +609,13 @@ static struct kvm *kvm_create_vm(unsigned long type)
{
int r, i;
struct kvm *kvm = kvm_arch_alloc_vm();
+ MM_REF(mm_ref);

if (!kvm)
return ERR_PTR(-ENOMEM);

spin_lock_init(&kvm->mmu_lock);
- mmgrab(current->mm);
+ mmgrab(current->mm, &kvm->mm_ref);
kvm->mm = current->mm;
kvm_eventfd_init(kvm);
mutex_init(&kvm->lock);
@@ -654,6 +656,7 @@ static struct kvm *kvm_create_vm(unsigned long type)
goto out_err;
}

+ INIT_MM_REF(&kvm->mmu_notifier_ref);
r = kvm_init_mmu_notifier(kvm);
if (r)
goto out_err;
@@ -677,8 +680,9 @@ static struct kvm *kvm_create_vm(unsigned long type)
kfree(kvm->buses[i]);
for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
kvm_free_memslots(kvm, kvm->memslots[i]);
+ move_mm_ref(kvm->mm, &kvm->mm_ref, &mm_ref);
kvm_arch_free_vm(kvm);
- mmdrop(current->mm);
+ mmdrop(current->mm, &mm_ref);
return ERR_PTR(r);
}

@@ -713,6 +717,7 @@ static void kvm_destroy_vm(struct kvm *kvm)
{
int i;
struct mm_struct *mm = kvm->mm;
+ MM_REF(mm_ref);

kvm_destroy_vm_debugfs(kvm);
kvm_arch_sync_events(kvm);
@@ -724,7 +729,7 @@ static void kvm_destroy_vm(struct kvm *kvm)
kvm_io_bus_destroy(kvm->buses[i]);
kvm_coalesced_mmio_free(kvm);
#if defined(CONFIG_MMU_NOTIFIER) && defined(KVM_ARCH_WANT_MMU_NOTIFIER)
- mmu_notifier_unregister(&kvm->mmu_notifier, kvm->mm);
+ mmu_notifier_unregister(&kvm->mmu_notifier, kvm->mm, &kvm->mmu_notifier_ref);
#else
kvm_arch_flush_shadow_all(kvm);
#endif
@@ -734,10 +739,11 @@ static void kvm_destroy_vm(struct kvm *kvm)
kvm_free_memslots(kvm, kvm->memslots[i]);
cleanup_srcu_struct(&kvm->irq_srcu);
cleanup_srcu_struct(&kvm->srcu);
+ move_mm_ref(mm, &kvm->mm_ref, &mm_ref);
kvm_arch_free_vm(kvm);
preempt_notifier_dec();
hardware_disable_all();
- mmdrop(mm);
+ mmdrop(mm, &mm_ref);
}

void kvm_get_kvm(struct kvm *kvm)
--
2.11.0.1.gaa10c3f