Re: [PATCH v1 2/4] mm/mempolicy: unify the preprocessing for mbind and set_mempolicy

From: Michal Hocko
Date: Thu May 27 2021 - 03:39:53 EST


On Wed 26-05-21 13:01:40, Feng Tang wrote:
> Currently the kernel_mbind() and kernel_set_mempolicy() do almost
> the same operation for parameter sanity check and preprocessing.
>
> Add a helper function to unify the code to reduce the redundancy,
> and make it easier for changing the pre-processing code in future.
>
> [thanks to David Rientjes for suggesting using helper function
> instead of macro]

I appreciate removing the code duplication, but I am not really convinced
this is an improvement. You are conflating two things: one is the mpol
flags checking and the other is the node mask copying. While abstracting
the first one makes sense to me, the latter is already a single line of
code and folding it in makes your helper unnecessarily complex. So I would
go with sanitize_mpol_flags, put the flags handling there, and leave
get_nodes alone.

> Signed-off-by: Feng Tang <feng.tang@xxxxxxxxx>
> ---
> mm/mempolicy.c | 43 ++++++++++++++++++++++++-------------------
> 1 file changed, 24 insertions(+), 19 deletions(-)

Funny how removing code duplication adds more code than it removes ;)

>
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index 1964cca..2830bb8 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -1460,6 +1460,20 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
> return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
> }
>
> +static inline int mpol_pre_process(int *mode, const unsigned long __user *nmask, unsigned long maxnode, nodemask_t *nodes, unsigned short *flags)
> +{
> + int ret;
> +
> + *flags = *mode & MPOL_MODE_FLAGS;
> + *mode &= ~MPOL_MODE_FLAGS;
> + if ((unsigned int)(*mode) >= MPOL_MAX)
> + return -EINVAL;
> + if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
> + return -EINVAL;
> + ret = get_nodes(nodes, nmask, maxnode);
> + return ret;
> +}
> +
> static long kernel_mbind(unsigned long start, unsigned long len,
> unsigned long mode, const unsigned long __user *nmask,
> unsigned long maxnode, unsigned int flags)
> @@ -1467,19 +1481,14 @@ static long kernel_mbind(unsigned long start, unsigned long len,
> nodemask_t nodes;
> int err;
> unsigned short mode_flags;
> + int lmode = mode;
>
> - start = untagged_addr(start);
> - mode_flags = mode & MPOL_MODE_FLAGS;
> - mode &= ~MPOL_MODE_FLAGS;
> - if (mode >= MPOL_MAX)
> - return -EINVAL;
> - if ((mode_flags & MPOL_F_STATIC_NODES) &&
> - (mode_flags & MPOL_F_RELATIVE_NODES))
> - return -EINVAL;
> - err = get_nodes(&nodes, nmask, maxnode);
> + err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
> if (err)
> return err;
> - return do_mbind(start, len, mode, mode_flags, &nodes, flags);
> +
> + start = untagged_addr(start);
> + return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
> }
>
> SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
> @@ -1495,18 +1504,14 @@ static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
> {
> int err;
> nodemask_t nodes;
> - unsigned short flags;
> + unsigned short mode_flags;
> + int lmode = mode;
>
> - flags = mode & MPOL_MODE_FLAGS;
> - mode &= ~MPOL_MODE_FLAGS;
> - if ((unsigned int)mode >= MPOL_MAX)
> - return -EINVAL;
> - if ((flags & MPOL_F_STATIC_NODES) && (flags & MPOL_F_RELATIVE_NODES))
> - return -EINVAL;
> - err = get_nodes(&nodes, nmask, maxnode);
> + err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
> if (err)
> return err;
> - return do_set_mempolicy(mode, flags, &nodes);
> +
> + return do_set_mempolicy(lmode, mode_flags, &nodes);
> }
>
> SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
> --
> 2.7.4

--
Michal Hocko
SUSE Labs