Re: [PATCH v2 2/5] mm: abstract the vma_merge()/split_vma() pattern for mprotect() et al.

From: Liam R. Howlett
Date: Tue Oct 10 2023 - 22:15:18 EST


* Lorenzo Stoakes <lstoakes@xxxxxxxxx> [231009 16:53]:
> mprotect() and other functions which change VMA parameters over a range
> each employ a pattern of:-
>
> 1. Attempt to merge the range with adjacent VMAs.
> 2. If this fails, and the range spans a subset of the VMA, split it
> accordingly.
>
> This is open-coded and duplicated in each case. Also in each case most of
> the parameters passed to vma_merge() remain the same.
>
> Create a new function, vma_modify(), which abstracts this operation,
> accepting only those parameters which can be changed.
>
> To avoid the mess of invoking each function call with unnecessary
> parameters, create inline wrapper functions for each of the modify
> operations, parameterised only by what is required to perform the action.
>
> Note that the userfaultfd_release() case works even though it does not
> split VMAs - since start is set to vma->vm_start and end is set to
> vma->vm_end, the split logic does not trigger.
>
> In addition, since we calculate pgoff to be equal to vma->vm_pgoff + ((start
> - vma->vm_start) >> PAGE_SHIFT), and start - vma->vm_start will be 0 in this
> instance, this invocation will remain unchanged.
>
> Signed-off-by: Lorenzo Stoakes <lstoakes@xxxxxxxxx>
> ---
> fs/userfaultfd.c | 69 +++++++++++++++-------------------------------
> include/linux/mm.h | 60 ++++++++++++++++++++++++++++++++++++++++
> mm/madvise.c | 32 ++++++---------------
> mm/mempolicy.c | 22 +++------------
> mm/mlock.c | 27 +++++-------------
> mm/mmap.c | 45 ++++++++++++++++++++++++++++++
> mm/mprotect.c | 35 +++++++----------------
> 7 files changed, 157 insertions(+), 133 deletions(-)
>
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index a7c6ef764e63..ba44a67a0a34 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -927,11 +927,10 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
> continue;
> }
> new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> - prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
> - new_flags, vma->anon_vma,
> - vma->vm_file, vma->vm_pgoff,
> - vma_policy(vma),
> - NULL_VM_UFFD_CTX, anon_vma_name(vma));
> + prev = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
> + vma->vm_end, new_flags,
> + NULL_VM_UFFD_CTX);
> +
> if (prev) {
> vma = prev;
> } else {
> @@ -1331,7 +1330,6 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> unsigned long start, end, vma_end;
> struct vma_iterator vmi;
> bool wp_async = userfaultfd_wp_async_ctx(ctx);
> - pgoff_t pgoff;
>
> user_uffdio_register = (struct uffdio_register __user *) arg;
>
> @@ -1484,28 +1482,17 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> vma_end = min(end, vma->vm_end);
>
> new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
> - pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> - prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> - vma->anon_vma, vma->vm_file, pgoff,
> - vma_policy(vma),
> - ((struct vm_userfaultfd_ctx){ ctx }),
> - anon_vma_name(vma));
> - if (prev) {
> - /* vma_merge() invalidated the mas */
> - vma = prev;
> - goto next;
> - }
> - if (vma->vm_start < start) {
> - ret = split_vma(&vmi, vma, start, 1);
> - if (ret)
> - break;
> - }
> - if (vma->vm_end > end) {
> - ret = split_vma(&vmi, vma, end, 0);
> - if (ret)
> - break;
> + prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> + new_flags,
> + (struct vm_userfaultfd_ctx){ctx});
> + if (IS_ERR(prev)) {
> + ret = PTR_ERR(prev);
> + break;
> }
> - next:
> +
> + if (prev)
> + vma = prev; /* vma_merge() invalidated the mas */

This comment is stale. The maple state lives in the vma iterator, which is
passed through, so it is not invalidated here. I missed this during the vma
iterator conversion.

> +
> /*
> * In the vma_merge() successful mprotect-like case 8:
> * the next vma was merged into the current one and
> @@ -1568,7 +1555,6 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> const void __user *buf = (void __user *)arg;
> struct vma_iterator vmi;
> bool wp_async = userfaultfd_wp_async_ctx(ctx);
> - pgoff_t pgoff;
>
> ret = -EFAULT;
> if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
> @@ -1671,26 +1657,15 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> uffd_wp_range(vma, start, vma_end - start, false);
>
> new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> - pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> - prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> - vma->anon_vma, vma->vm_file, pgoff,
> - vma_policy(vma),
> - NULL_VM_UFFD_CTX, anon_vma_name(vma));
> - if (prev) {
> - vma = prev;
> - goto next;
> - }
> - if (vma->vm_start < start) {
> - ret = split_vma(&vmi, vma, start, 1);
> - if (ret)
> - break;
> - }
> - if (vma->vm_end > end) {
> - ret = split_vma(&vmi, vma, end, 0);
> - if (ret)
> - break;
> + prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> + new_flags, NULL_VM_UFFD_CTX);
> + if (IS_ERR(prev)) {
> + ret = PTR_ERR(prev);
> + break;
> }
> - next:
> +
> + if (prev)
> + vma = prev;
> /*
> * In the vma_merge() successful mprotect-like case 8:
> * the next vma was merged into the current one and
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index a7b667786cde..83ee1f35febe 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -3253,6 +3253,66 @@ extern struct vm_area_struct *copy_vma(struct vm_area_struct **,
> unsigned long addr, unsigned long len, pgoff_t pgoff,
> bool *need_rmap_locks);
> extern void exit_mmap(struct mm_struct *);
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start, unsigned long end,
> + unsigned long vm_flags,
> + struct mempolicy *policy,
> + struct vm_userfaultfd_ctx uffd_ctx,
> + struct anon_vma_name *anon_name);
> +
> +/* We are about to modify the VMA's flags. */
> +static inline struct vm_area_struct
> +*vma_modify_flags(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start, unsigned long end,
> + unsigned long new_flags)
> +{
> + return vma_modify(vmi, prev, vma, start, end, new_flags,
> + vma_policy(vma), vma->vm_userfaultfd_ctx,
> + anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or anon_name. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_name(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start,
> + unsigned long end,
> + unsigned long new_flags,
> + struct anon_vma_name *new_name)
> +{
> + return vma_modify(vmi, prev, vma, start, end, new_flags,
> + vma_policy(vma), vma->vm_userfaultfd_ctx, new_name);
> +}
> +
> +/* We are about to modify the VMA's memory policy. */
> +static inline struct vm_area_struct
> +*vma_modify_policy(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start, unsigned long end,
> + struct mempolicy *new_pol)
> +{
> + return vma_modify(vmi, prev, vma, start, end, vma->vm_flags,
> + new_pol, vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or uffd context. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_uffd(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start, unsigned long end,
> + unsigned long new_flags,
> + struct vm_userfaultfd_ctx new_ctx)
> +{
> + return vma_modify(vmi, prev, vma, start, end, new_flags,
> + vma_policy(vma), new_ctx, anon_vma_name(vma));
> +}
>
> static inline int check_data_rlimit(unsigned long rlim,
> unsigned long new,
> diff --git a/mm/madvise.c b/mm/madvise.c
> index a4a20de50494..801d3c1bb7b3 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -141,7 +141,7 @@ static int madvise_update_vma(struct vm_area_struct *vma,
> {
> struct mm_struct *mm = vma->vm_mm;
> int error;
> - pgoff_t pgoff;
> + struct vm_area_struct *merged;
> VMA_ITERATOR(vmi, mm, start);
>
> if (new_flags == vma->vm_flags && anon_vma_name_eq(anon_vma_name(vma), anon_name)) {
> @@ -149,30 +149,16 @@ static int madvise_update_vma(struct vm_area_struct *vma,
> return 0;
> }
>
> - pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> - *prev = vma_merge(&vmi, mm, *prev, start, end, new_flags,
> - vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> - vma->vm_userfaultfd_ctx, anon_name);
> - if (*prev) {
> - vma = *prev;
> - goto success;
> - }
> -
> - *prev = vma;
> -
> - if (start != vma->vm_start) {
> - error = split_vma(&vmi, vma, start, 1);
> - if (error)
> - return error;
> - }
> + merged = vma_modify_flags_name(&vmi, *prev, vma, start, end, new_flags,
> + anon_name);
> + if (IS_ERR(merged))
> + return PTR_ERR(merged);
>
> - if (end != vma->vm_end) {
> - error = split_vma(&vmi, vma, end, 0);
> - if (error)
> - return error;
> - }
> + if (merged)
> + vma = *prev = merged;
> + else
> + *prev = vma;
>
> -success:
> /* vm_flags is protected by the mmap_lock held in write mode. */
> vma_start_write(vma);
> vm_flags_reset(vma, new_flags);
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index b01922e88548..6b2e99db6dd5 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -786,8 +786,6 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
> {
> struct vm_area_struct *merged;
> unsigned long vmstart, vmend;
> - pgoff_t pgoff;
> - int err;
>
> vmend = min(end, vma->vm_end);
> if (start > vma->vm_start) {
> @@ -802,27 +800,15 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
> return 0;
> }
>
> - pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
> - merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
> - vma->anon_vma, vma->vm_file, pgoff, new_pol,
> - vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> + merged = vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
> + if (IS_ERR(merged))
> + return PTR_ERR(merged);
> +
> if (merged) {
> *prev = merged;
> return vma_replace_policy(merged, new_pol);
> }
>
> - if (vma->vm_start != vmstart) {
> - err = split_vma(vmi, vma, vmstart, 1);
> - if (err)
> - return err;
> - }
> -
> - if (vma->vm_end != vmend) {
> - err = split_vma(vmi, vma, vmend, 0);
> - if (err)
> - return err;
> - }
> -
> *prev = vma;
> return vma_replace_policy(vma, new_pol);
> }
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 42b6865f8f82..ae83a33c387e 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -476,10 +476,10 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
> unsigned long end, vm_flags_t newflags)
> {
> struct mm_struct *mm = vma->vm_mm;
> - pgoff_t pgoff;
> int nr_pages;
> int ret = 0;
> vm_flags_t oldflags = vma->vm_flags;
> + struct vm_area_struct *merged;
>
> if (newflags == oldflags || (oldflags & VM_SPECIAL) ||
> is_vm_hugetlb_page(vma) || vma == get_gate_vma(current->mm) ||
> @@ -487,28 +487,15 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
> /* don't set VM_LOCKED or VM_LOCKONFAULT and don't count */
> goto out;
>
> - pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> - *prev = vma_merge(vmi, mm, *prev, start, end, newflags,
> - vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> - vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> - if (*prev) {
> - vma = *prev;
> - goto success;
> - }
> -
> - if (start != vma->vm_start) {
> - ret = split_vma(vmi, vma, start, 1);
> - if (ret)
> - goto out;
> + merged = vma_modify_flags(vmi, *prev, vma, start, end, newflags);
> + if (IS_ERR(merged)) {
> + ret = PTR_ERR(merged);
> + goto out;
> }
>
> - if (end != vma->vm_end) {
> - ret = split_vma(vmi, vma, end, 0);
> - if (ret)
> - goto out;
> - }
> + if (merged)
> + vma = *prev = merged;
>
> -success:
> /*
> * Keep track of amount of locked VM.
> */
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 673429ee8a9e..22d968affc07 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2437,6 +2437,51 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
> return __split_vma(vmi, vma, addr, new_below);
> }
>
> +/*
> + * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
> + * context and anonymous VMA name within the range [start, end).
> + *
> + * As a result, we might be able to merge the newly modified VMA range with an
> + * adjacent VMA with identical properties.
> + *
> + * If no merge is possible and the range does not span the entirety of the VMA,
> + * we then need to split the VMA to accommodate the change.
> + */
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> + struct vm_area_struct *prev,
> + struct vm_area_struct *vma,
> + unsigned long start, unsigned long end,
> + unsigned long vm_flags,
> + struct mempolicy *policy,
> + struct vm_userfaultfd_ctx uffd_ctx,
> + struct anon_vma_name *anon_name)
> +{
> + pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> + struct vm_area_struct *merged;
> +
> + merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
> + vma->anon_vma, vma->vm_file, pgoff, policy,
> + uffd_ctx, anon_name);
> + if (merged)
> + return merged;
> +
> + if (vma->vm_start < start) {
> + int err = split_vma(vmi, vma, start, 1);
> +
> + if (err)
> + return ERR_PTR(err);
> + }
> +
> + if (vma->vm_end > end) {
> + int err = split_vma(vmi, vma, end, 0);
> +
> + if (err)
> + return ERR_PTR(err);
> + }
> +
> + return NULL;
> +}
> +
> /*
> * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
> * @vmi: The vma iterator
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index b94fbb45d5c7..6f85d99682ab 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -581,7 +581,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
> long nrpages = (end - start) >> PAGE_SHIFT;
> unsigned int mm_cp_flags = 0;
> unsigned long charged = 0;
> - pgoff_t pgoff;
> + struct vm_area_struct *merged;
> int error;
>
> if (newflags == oldflags) {
> @@ -625,34 +625,19 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
> }
> }
>
> - /*
> - * First try to merge with previous and/or next vma.
> - */
> - pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> - *pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
> - vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> - vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> - if (*pprev) {
> - vma = *pprev;
> - VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> - goto success;
> + merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
> + if (IS_ERR(merged)) {
> + error = PTR_ERR(merged);
> + goto fail;
> }
>
> - *pprev = vma;
> -
> - if (start != vma->vm_start) {
> - error = split_vma(vmi, vma, start, 1);
> - if (error)
> - goto fail;
> - }
> -
> - if (end != vma->vm_end) {
> - error = split_vma(vmi, vma, end, 0);
> - if (error)
> - goto fail;
> + if (merged) {
> + vma = *pprev = merged;
> + VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> + } else {
> + *pprev = vma;
> }
>
> -success:
> /*
> * vm_flags and vm_page_prot are protected by the mmap_lock
> * held in write mode.
> --
> 2.42.0
>