[RFC PATCH 03/11] mm/mempolicy: refactor set_mempolicy stack to take a task argument

From: Gregory Price
Date: Wed Nov 22 2023 - 16:12:30 EST


To make a task's mempolicy modifiable by other tasks, we must refactor
the set_mempolicy call stack to take the target task as an argument.

Modify the following functions to require a task argument:
mpol_set_nodemask
replace_mempolicy
do_set_mempolicy
kernel_set_mempolicy

Since replace_mempolicy already acquires the task lock (now taken on
the task passed in rather than on current), there is no need to change
any locking behavior.

All other callers of mpol_set_nodemask (as of this patch) call it in
the context of current, with either the task lock or the mmap lock
held, so no other changes are required.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
mm/mempolicy.c | 51 +++++++++++++++++++++++++++-----------------------
1 file changed, 28 insertions(+), 23 deletions(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 37da712259d7..9ea3e1bfc002 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -226,8 +226,10 @@ static int mpol_new_preferred(struct mempolicy *pol, const nodemask_t *nodes)
* Must be called holding task's alloc_lock to protect task's mems_allowed
* and mempolicy. May also be called holding the mmap_lock for write.
*/
-static int mpol_set_nodemask(struct mempolicy *pol,
- const nodemask_t *nodes, struct nodemask_scratch *nsc)
+static int mpol_set_nodemask(struct task_struct *tsk,
+ struct mempolicy *pol,
+ const nodemask_t *nodes,
+ struct nodemask_scratch *nsc)
{
int ret;

@@ -240,8 +242,7 @@ static int mpol_set_nodemask(struct mempolicy *pol,
return 0;

/* Check N_MEMORY */
- nodes_and(nsc->mask1,
- cpuset_current_mems_allowed, node_states[N_MEMORY]);
+ nodes_and(nsc->mask1, tsk->mems_allowed, node_states[N_MEMORY]);

VM_BUG_ON(!nodes);

@@ -253,7 +254,7 @@ static int mpol_set_nodemask(struct mempolicy *pol,
if (mpol_store_user_nodemask(pol))
pol->w.user_nodemask = *nodes;
else
- pol->w.cpuset_mems_allowed = cpuset_current_mems_allowed;
+ pol->w.cpuset_mems_allowed = tsk->mems_allowed;

ret = mpol_ops[pol->mode].create(pol, &nsc->mask2);
return ret;
@@ -810,7 +811,9 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
}

/* Attempt to replace mempolicy, release the old one if successful */
-static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
+static long replace_mempolicy(struct task_struct *task,
+ struct mempolicy *new,
+ nodemask_t *nodes)
{
struct mempolicy *old = NULL;
NODEMASK_SCRATCH(scratch);
@@ -819,19 +822,19 @@ static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
if (!scratch)
return -ENOMEM;

- task_lock(current);
- ret = mpol_set_nodemask(new, nodes, scratch);
+ task_lock(task);
+ ret = mpol_set_nodemask(task, new, nodes, scratch);
if (ret) {
- task_unlock(current);
+ task_unlock(task);
goto out;
}

- old = current->mempolicy;
- current->mempolicy = new;
+ old = task->mempolicy;
+ task->mempolicy = new;
if (new && new->mode == MPOL_INTERLEAVE)
- current->il_prev = MAX_NUMNODES-1;
+ task->il_prev = MAX_NUMNODES-1;
out:
- task_unlock(current);
+ task_unlock(task);
mpol_put(old);

NODEMASK_SCRATCH_FREE(scratch);
@@ -839,8 +842,8 @@ static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
}

/* Set the process memory policy */
-static long do_set_mempolicy(unsigned short mode, unsigned short flags,
- nodemask_t *nodes)
+static long do_set_mempolicy(struct task_struct *task, unsigned short mode,
+ unsigned short flags, nodemask_t *nodes)
{
struct mempolicy *new;
int ret;
@@ -849,7 +852,7 @@ static long do_set_mempolicy(unsigned short mode, unsigned short flags,
if (IS_ERR(new))
return PTR_ERR(new);

- ret = replace_mempolicy(new, nodes);
+ ret = replace_mempolicy(task, new, nodes);
if (ret)
mpol_put(new);

@@ -1284,7 +1287,7 @@ static long do_mbind(unsigned long start, unsigned long len,
NODEMASK_SCRATCH(scratch);
if (scratch) {
mmap_write_lock(mm);
- err = mpol_set_nodemask(new, nmask, scratch);
+ err = mpol_set_nodemask(current, new, nmask, scratch);
if (err)
mmap_write_unlock(mm);
} else
@@ -1580,7 +1583,8 @@ SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
}

/* Set the process memory policy */
-static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
+static long kernel_set_mempolicy(struct task_struct *task, int mode,
+ const unsigned long __user *nmask,
unsigned long maxnode)
{
unsigned short mode_flags;
@@ -1596,13 +1600,13 @@ static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
if (err)
return err;

- return do_set_mempolicy(lmode, mode_flags, &nodes);
+ return do_set_mempolicy(task, lmode, mode_flags, &nodes);
}

SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
unsigned long, maxnode)
{
- return kernel_set_mempolicy(mode, nmask, maxnode);
+ return kernel_set_mempolicy(current, mode, nmask, maxnode);
}

static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
@@ -2722,7 +2726,8 @@ void mpol_shared_policy_init(struct shared_policy *sp, struct mempolicy *mpol)
goto free_scratch; /* no valid nodemask intersection */

task_lock(current);
- ret = mpol_set_nodemask(npol, &mpol->w.user_nodemask, scratch);
+ ret = mpol_set_nodemask(current, npol, &mpol->w.user_nodemask,
+ scratch);
task_unlock(current);
if (ret)
goto put_npol;
@@ -2870,7 +2875,7 @@ void __init numa_policy_init(void)
if (unlikely(nodes_empty(interleave_nodes)))
node_set(prefer, interleave_nodes);

- if (do_set_mempolicy(MPOL_INTERLEAVE, 0, &interleave_nodes))
+ if (do_set_mempolicy(current, MPOL_INTERLEAVE, 0, &interleave_nodes))
pr_err("%s: interleaving failed\n", __func__);

check_numabalancing_enable();
@@ -2879,7 +2884,7 @@ void __init numa_policy_init(void)
/* Reset policy of current process to default */
void numa_default_policy(void)
{
- do_set_mempolicy(MPOL_DEFAULT, 0, NULL);
+ do_set_mempolicy(current, MPOL_DEFAULT, 0, NULL);
}

/*
--
2.39.1