[RFC PATCH 2/3] mm/mempolicy: Implement set_mempolicy2 and get_mempolicy2 syscalls

From: Gregory Price
Date: Thu Sep 14 2023 - 19:55:16 EST


sys_set_mempolicy is limited by its current argument structure
(mode with the flags folded in, nodemask, maxnode): it can only
implement policies that are fully described by those three arguments.

Implement set/get_mempolicy2 with a new mempolicy_args structure
which encapsulates the old behavior, and allows for new mempolicies
which may require additional information.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
arch/x86/entry/syscalls/syscall_32.tbl | 2 +
arch/x86/entry/syscalls/syscall_64.tbl | 2 +
include/linux/syscalls.h | 2 +
include/uapi/asm-generic/unistd.h | 10 +-
include/uapi/linux/mempolicy.h | 32 ++++
mm/mempolicy.c | 215 ++++++++++++++++++++++++-
6 files changed, 261 insertions(+), 2 deletions(-)

diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl
index 2d0b1bd866ea..a72ef588a704 100644
--- a/arch/x86/entry/syscalls/syscall_32.tbl
+++ b/arch/x86/entry/syscalls/syscall_32.tbl
@@ -457,3 +457,5 @@
450 i386 set_mempolicy_home_node sys_set_mempolicy_home_node
451 i386 cachestat sys_cachestat
452 i386 fchmodat2 sys_fchmodat2
+454 i386 set_mempolicy2 sys_set_mempolicy2
+455 i386 get_mempolicy2 sys_get_mempolicy2
diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl
index 1d6eee30eceb..ec54064de8b3 100644
--- a/arch/x86/entry/syscalls/syscall_64.tbl
+++ b/arch/x86/entry/syscalls/syscall_64.tbl
@@ -375,6 +375,8 @@
451 common cachestat sys_cachestat
452 common fchmodat2 sys_fchmodat2
453 64 map_shadow_stack sys_map_shadow_stack
+454 common set_mempolicy2 sys_set_mempolicy2
+455 common get_mempolicy2 sys_get_mempolicy2

#
# Due to a historical design error, certain syscalls are numbered differently
diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h
index 22bc6bc147f8..d50a452954ae 100644
--- a/include/linux/syscalls.h
+++ b/include/linux/syscalls.h
@@ -813,6 +813,8 @@ asmlinkage long sys_get_mempolicy(int __user *policy,
unsigned long addr, unsigned long flags);
asmlinkage long sys_set_mempolicy(int mode, const unsigned long __user *nmask,
unsigned long maxnode);
+asmlinkage long sys_get_mempolicy2(struct mempolicy_args __user *args, size_t size);
+asmlinkage long sys_set_mempolicy2(const struct mempolicy_args __user *args, size_t size);
asmlinkage long sys_migrate_pages(pid_t pid, unsigned long maxnode,
const unsigned long __user *from,
const unsigned long __user *to);
diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h
index abe087c53b4b..397dcf804941 100644
--- a/include/uapi/asm-generic/unistd.h
+++ b/include/uapi/asm-generic/unistd.h
@@ -823,8 +823,16 @@ __SYSCALL(__NR_cachestat, sys_cachestat)
#define __NR_fchmodat2 452
__SYSCALL(__NR_fchmodat2, sys_fchmodat2)

+/* CONFIG_MMU only */
+#ifndef __ARCH_NOMMU
+#define __NR_set_mempolicy2 454
+__SYSCALL(__NR_set_mempolicy2, sys_set_mempolicy2)
+#define __NR_get_mempolicy2 455
+__SYSCALL(__NR_get_mempolicy2, sys_get_mempolicy2)
+#endif
+
#undef __NR_syscalls
-#define __NR_syscalls 453
+#define __NR_syscalls 456

/*
* 32 bit systems traditionally used different
diff --git a/include/uapi/linux/mempolicy.h b/include/uapi/linux/mempolicy.h
index 046d0ccba4cd..53650f69db2b 100644
--- a/include/uapi/linux/mempolicy.h
+++ b/include/uapi/linux/mempolicy.h
@@ -23,9 +23,41 @@ enum {
MPOL_INTERLEAVE,
MPOL_LOCAL,
MPOL_PREFERRED_MANY,
+ MPOL_LEGACY, /* set_mempolicy limited to above modes */
MPOL_MAX, /* always last member of enum */
};

+/*
+ * Argument block for the set_mempolicy2() and get_mempolicy2() syscalls.
+ *
+ * The top-level mode/flags/nodemask/maxnode describe the task policy: they
+ * are the request for set_mempolicy2() and are reported by get_mempolicy2().
+ * The "get" sub-structure is only used by get_mempolicy2(); each of its
+ * parts is optional and is skipped when its selector field
+ * (allowed.nodemask, addr.addr) is zero.
+ *
+ * The nodemask pointers refer to user memory, hence the __user annotation.
+ */
+struct mempolicy_args {
+	int err;			/* get: status summary written by the kernel */
+	unsigned short mode;		/* MPOL_* mode of the task policy */
+	unsigned long __user *nodemask;	/* user buffer for the policy nodemask */
+	unsigned long maxnode;		/* size argument paired with nodemask */
+	unsigned short flags;		/* MPOL_F_* mode flags */
+	struct {
+		/* Memory allowed */
+		struct {
+			int err;	/* 0, or negative errno of the copy-out */
+			unsigned long maxnode;	/* size argument paired with nodemask */
+			unsigned long __user *nodemask;	/* user buffer; NULL skips this query */
+		} allowed;
+		/* Address information */
+		struct {
+			int err;	/* 0, or negative errno of the lookup */
+			unsigned long addr;	/* address to query; 0 skips this query */
+			unsigned long node;	/* node id found for addr */
+			unsigned short mode;	/* mode reported for the policy at addr */
+			unsigned short flags;	/* flags reported for the policy at addr */
+		} addr;
+	} get;
+	/* Mode specific settings */
+	union {
+		struct {
+			/* get only: node following the task's last interleave node */
+			unsigned long next_node;
+		} interleave;
+	};
+};
+
/* Flags for set_mempolicy */
#define MPOL_F_STATIC_NODES (1 << 15)
#define MPOL_F_RELATIVE_NODES (1 << 14)
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index f49337f6f300..1cf7709400f1 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1483,7 +1483,7 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
*flags = *mode & MPOL_MODE_FLAGS;
*mode &= ~MPOL_MODE_FLAGS;

- if ((unsigned int)(*mode) >= MPOL_MAX)
+ if ((unsigned int)(*mode) >= MPOL_LEGACY)
return -EINVAL;
if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
return -EINVAL;
@@ -1614,6 +1614,219 @@ SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
return kernel_set_mempolicy(mode, nmask, maxnode);
}

+/*
+ * Build a mempolicy from @args and install it through swap_mempolicy().
+ *
+ * Only modes strictly between MPOL_LEGACY and MPOL_MAX are accepted here;
+ * everything else is rejected with -EINVAL.
+ *
+ * Return: 0 on success, negative errno otherwise.  On failure the newly
+ * created policy, if any, is released again.
+ */
+static long do_set_mempolicy2(struct mempolicy_args *args)
+{
+	struct mempolicy *new;
+	nodemask_t nodes;
+	int err;
+
+	if (args->mode <= MPOL_LEGACY || args->mode >= MPOL_MAX)
+		return -EINVAL;
+
+	err = get_nodes(&nodes, args->nodemask, args->maxnode);
+	if (err)
+		return err;
+
+	new = mpol_new(args->mode, args->flags, &nodes);
+	/*
+	 * An ERR_PTR is not a policy reference: return right away instead of
+	 * going through the cleanup below, which would mpol_put() it.
+	 */
+	if (IS_ERR(new))
+		return PTR_ERR(new);
+
+	switch (args->mode) {
+	default:
+		/*
+		 * No extended mode has mode-specific setup yet.  Getting here
+		 * means a mode was added to the enum but not to this switch;
+		 * warn and fail the call rather than taking the machine down.
+		 */
+		WARN_ON_ONCE(1);
+		err = -EINVAL;
+		break;
+	}
+
+	if (!err)
+		err = swap_mempolicy(new, &nodes);
+	if (err)
+		mpol_put(new);
+	return err;
+}
+
+/*
+ * Tell whether @kargs carries a mode the extended set path may handle:
+ * strictly above MPOL_LEGACY (legacy modes go through the legacy
+ * interface) and strictly below MPOL_MAX.
+ */
+static bool mempolicy2_args_valid(struct mempolicy_args *kargs)
+{
+	const unsigned short mode = kargs->mode;
+
+	return mode > MPOL_LEGACY && mode < MPOL_MAX;
+}
+
+/*
+ * Copy a mempolicy_args block in from user space and apply it.
+ *
+ * @usize has to match sizeof(struct mempolicy_args) exactly, otherwise the
+ * call fails with -EINVAL.  Modes below MPOL_LEGACY are handed to
+ * kernel_set_mempolicy() with mode and flags OR-ed back into one value;
+ * all other modes are validated and passed to do_set_mempolicy2().
+ */
+static long kernel_set_mempolicy2(const struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args args;
+	int ret;
+
+	if (usize != sizeof(args))
+		return -EINVAL;
+
+	ret = copy_struct_from_user(&args, sizeof(args), uargs, usize);
+	if (ret)
+		return ret;
+
+	if (args.mode >= MPOL_LEGACY) {
+		/* Extended path: only modes the new interface knows about */
+		if (!mempolicy2_args_valid(&args))
+			return -EINVAL;
+		return do_set_mempolicy2(&args);
+	}
+
+	/* Legacy path: the old entry point expects flags folded into mode */
+	return kernel_set_mempolicy(args.mode | args.flags, args.nodemask,
+				    args.maxnode);
+}
+
+/*
+ * set_mempolicy2() syscall entry point: passes the user pointer and the
+ * size through unchanged to kernel_set_mempolicy2() and returns its result.
+ */
+SYSCALL_DEFINE2(set_mempolicy2, const struct mempolicy_args __user *, args,
+ size_t, size)
+{
+ return kernel_set_mempolicy2(args, size);
+}
+
+/*
+ * Bits set in mempolicy_args.err when a get_mempolicy2() sub-query fails.
+ * NOTE(review): user space reads these values, so they belong in the UAPI
+ * header next to struct mempolicy_args once the names are agreed on.
+ */
+enum {
+	MPOL_GET2_ERR_ALLOWED	= 1,	/* copy-out of the allowed nodes failed */
+	MPOL_GET2_ERR_ADDR_VMA	= 2,	/* no VMA covers get.addr.addr */
+	MPOL_GET2_ERR_ADDR_NODE	= 4,	/* node lookup for get.addr.addr failed */
+};
+
+/*
+ * Gets extended mempolicy information for the calling task.
+ *
+ * Always reported: mode and user-visible flags of the task policy.
+ * Optional, each requested through a non-zero field of @kargs:
+ *   kargs->nodemask              nodemask of the task policy
+ *   kargs->get.allowed.nodemask  nodes the task's cpuset allows
+ *   kargs->get.addr.addr         policy mode/flags and node for that address
+ *
+ * A failure to copy out the task policy nodemask fails the whole call.  A
+ * failure in one of the two "get" sub-queries does not: its errno is stored
+ * in the sub-query's own err field, the matching MPOL_GET2_ERR_* bit is set
+ * in kargs->err, and 0 is returned so that the caller still copies @kargs
+ * back to user space.
+ *
+ * Return: 0, or the negative errno of the task policy nodemask copy.
+ */
+static long do_get_mempolicy2(struct mempolicy_args *kargs)
+{
+	struct mempolicy *pol = current->mempolicy;
+	nodemask_t knodes;
+	int err;
+
+	/* A task that never set a policy has a NULL mempolicy pointer */
+	if (!pol)
+		pol = &default_policy;
+
+	kargs->err = 0;
+	/*
+	 * NOTE(review): the built-in default policy is reported as
+	 * MPOL_DEFAULT to match what the legacy get_mempolicy() path does;
+	 * confirm that is the wanted behaviour here.
+	 */
+	kargs->mode = pol == &default_policy ? MPOL_DEFAULT : pol->mode;
+	/* Mask off internal flags */
+	kargs->flags = (pol->flags & MPOL_MODE_FLAGS);
+
+	if (kargs->nodemask) {
+		if (mpol_store_user_nodemask(pol)) {
+			knodes = pol->w.user_nodemask;
+		} else {
+			task_lock(current);
+			get_policy_nodemask(pol, &knodes);
+			task_unlock(current);
+		}
+		err = copy_nodes_to_user(kargs->nodemask, kargs->maxnode,
+					 &knodes);
+		/* Hand back the real errno instead of rewriting it */
+		if (err)
+			return err;
+	}
+
+	if (kargs->get.allowed.nodemask) {
+		task_lock(current);
+		knodes = cpuset_current_mems_allowed;
+		task_unlock(current);
+		err = copy_nodes_to_user(kargs->get.allowed.nodemask,
+					 kargs->get.allowed.maxnode,
+					 &knodes);
+		kargs->get.allowed.err = err;
+		if (err)
+			kargs->err |= MPOL_GET2_ERR_ALLOWED;
+	}
+
+	if (kargs->get.addr.addr) {
+		struct mm_struct *mm = current->mm;
+		unsigned long addr = kargs->get.addr.addr;
+		struct mempolicy *addr_pol;
+		struct vm_area_struct *vma;
+		int nid;
+
+		kargs->get.addr.err = 0;
+
+		mmap_read_lock(mm);
+		vma = vma_lookup(mm, addr);
+		if (!vma) {
+			mmap_read_unlock(mm);
+			kargs->get.addr.err = -EFAULT;
+			kargs->err |= MPOL_GET2_ERR_ADDR_VMA;
+			goto mode_info;
+		}
+		if (vma->vm_ops && vma->vm_ops->get_policy)
+			addr_pol = vma->vm_ops->get_policy(vma, addr);
+		else
+			addr_pol = vma->vm_policy;
+
+		/*
+		 * Do NOT fall back to task policy if the vma/shared policy
+		 * at addr is NULL: report MPOL_DEFAULT with no flags.
+		 */
+		if (addr_pol) {
+			kargs->get.addr.mode = addr_pol->mode;
+			/* Mask off internal flags */
+			kargs->get.addr.flags = addr_pol->flags & MPOL_MODE_FLAGS;
+		} else {
+			kargs->get.addr.mode = MPOL_DEFAULT;
+			kargs->get.addr.flags = 0;
+		}
+
+		/*
+		 * addr_pol is not used past this point, so no extra
+		 * reference is needed across the unlock.
+		 * NOTE(review): vm_ops->get_policy() is expected to return
+		 * shared policies with an extra reference; mpol_cond_put()
+		 * drops exactly that one and ignores NULL - confirm against
+		 * the get_policy() contract of this tree.
+		 */
+		mpol_cond_put(addr_pol);
+		mmap_read_unlock(mm);
+
+		/* Keep the node id out of the error path: it is a result */
+		nid = lookup_node(mm, addr);
+		if (nid < 0) {
+			kargs->get.addr.err = nid;
+			kargs->err |= MPOL_GET2_ERR_ADDR_NODE;
+			goto mode_info;
+		}
+		kargs->get.addr.node = nid;
+	}
+
+mode_info:
+	switch (kargs->mode) {
+	case MPOL_INTERLEAVE:
+		kargs->interleave.next_node = next_node_in(current->il_prev,
+							   pol->nodes);
+		break;
+	default:
+		break;
+	}
+
+	return 0;
+}
+
+/*
+ * Copy the caller's mempolicy_args block in, let do_get_mempolicy2() fill
+ * it, and copy the result back to @uargs.
+ *
+ * @usize has to match sizeof(struct mempolicy_args) exactly, otherwise the
+ * call fails with -EINVAL.
+ *
+ * Return: 0 on success, -EFAULT if the result cannot be written back, or
+ * whatever copy_struct_from_user() / do_get_mempolicy2() returned.
+ */
+static long kernel_get_mempolicy2(struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args kargs;
+	int err;
+
+	if (usize != sizeof(kargs))
+		return -EINVAL;
+
+	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
+	if (err)
+		return err;
+
+	/* Fill in the task policy and any requested kargs.get sub-queries */
+	err = do_get_mempolicy2(&kargs);
+	if (err)
+		return err;
+
+	/* copy_to_user() returns the number of bytes NOT copied, not an errno */
+	if (copy_to_user(uargs, &kargs, sizeof(kargs)))
+		return -EFAULT;
+
+	return 0;
+}
+
+/*
+ * get_mempolicy2() syscall entry point: passes the user pointer and the
+ * size through unchanged to kernel_get_mempolicy2() and returns its result.
+ */
+SYSCALL_DEFINE2(get_mempolicy2, struct mempolicy_args __user *, policy,
+ size_t, size)
+{
+ return kernel_get_mempolicy2(policy, size);
+}
+
static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
const unsigned long __user *old_nodes,
const unsigned long __user *new_nodes)
--
2.39.1