Re: [PATCH v4 43/66] userfaultfd: Use maple tree iterator to iterate VMAs

From: Vlastimil Babka
Date: Wed Jan 19 2022 - 11:26:10 EST


On 12/1/21 15:30, Liam Howlett wrote:
> From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>
>
> Don't use the mm_struct linked list or the vma->vm_next in prep for removal
>
> Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
> ---
> fs/userfaultfd.c | 49 ++++++++++++++++++++++-------------
> include/linux/userfaultfd_k.h | 7 +++--
> mm/mmap.c | 12 ++++-----
> 3 files changed, 40 insertions(+), 28 deletions(-)
>
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index 22bf14ab2d16..2880025598c7 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
> if (release_new_ctx) {
> struct vm_area_struct *vma;
> struct mm_struct *mm = release_new_ctx->mm;
> + VMA_ITERATOR(vmi, mm, 0);
>
> /* the various vma->vm_userfaultfd_ctx still points to it */
> mmap_write_lock(mm);
> - for (vma = mm->mmap; vma; vma = vma->vm_next)
> + for_each_vma(vmi, vma) {
> if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
> vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
> vma->vm_flags &= ~__VM_UFFD_FLAGS;
> }
> + }
> mmap_write_unlock(mm);
>
> userfaultfd_ctx_put(release_new_ctx);
> @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps,
> return false;
> }
>
> -int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> - unsigned long start, unsigned long end,
> - struct list_head *unmaps)
> +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
> + unsigned long end, struct list_head *unmaps)
> {
> - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) {
> + VMA_ITERATOR(vmi, mm, start);
> + struct vm_area_struct *vma;
> +
> + for_each_vma_range(vmi, vma, end) {
> struct userfaultfd_unmap_ctx *unmap_ctx;
> struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx;
>
> @@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
> /* len == 0 means wake all */
> struct userfaultfd_wake_range range = { .len = 0, };
> unsigned long new_flags;
> + MA_STATE(mas, &mm->mm_mt, 0, 0);

Again, it looks like this could also be a VMA_ITERATOR, for consistency with
the one above?

>
> WRITE_ONCE(ctx->released, true);
>
> @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
> */
> mmap_write_lock(mm);
> prev = NULL;
> - for (vma = mm->mmap; vma; vma = vma->vm_next) {
> + mas_for_each(&mas, vma, ULONG_MAX) {
> cond_resched();
> BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
> !!(vma->vm_flags & __VM_UFFD_FLAGS));
> @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> bool found;
> bool basic_ioctls;
> unsigned long start, end, vma_end;
> + MA_STATE(mas, &mm->mm_mt, 0, 0);
>
> user_uffdio_register = (struct uffdio_register __user *) arg;
>
> @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> goto out;
>
> mmap_write_lock(mm);
> - vma = find_vma_prev(mm, start, &prev);
> + mas_set(&mas, start);
> + vma = mas_find(&mas, ULONG_MAX);
> if (!vma)
> goto out_unlock;
>
> @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> */
> found = false;
> basic_ioctls = false;
> - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
> + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
> cond_resched();
>
> BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
> @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> }
> BUG_ON(!found);
>
> - if (vma->vm_start < start)
> - prev = vma;
> + mas_set(&mas, start);
> + prev = mas_prev(&mas, 0);
> + if (prev != vma)
> + mas_next(&mas, ULONG_MAX);

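Hmm, this is tricky and uncommented. Please add a comment explaining why prev
is looked up again with mas_prev() from start, and why mas_next() is called
only when prev != vma.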

>
> ret = 0;
> do {
> @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> skip:
> prev = vma;
> start = vma->vm_end;
> - vma = vma->vm_next;
> - } while (vma && vma->vm_start < end);
> + vma = mas_next(&mas, end - 1);
> + } while (vma);
> out_unlock:
> mmap_write_unlock(mm);
> mmput(mm);
> @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> bool found;
> unsigned long start, end, vma_end;
> const void __user *buf = (void __user *)arg;
> + MA_STATE(mas, &mm->mm_mt, 0, 0);
>
> ret = -EFAULT;
> if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
> @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> goto out;
>
> mmap_write_lock(mm);
> - vma = find_vma_prev(mm, start, &prev);
> + mas_set(&mas, start);
> + vma = mas_find(&mas, ULONG_MAX);
> if (!vma)
> goto out_unlock;
>
> @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> */
> found = false;
> ret = -EINVAL;
> - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
> + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
> cond_resched();
>
> BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
> @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> }
> BUG_ON(!found);
>
> - if (vma->vm_start < start)
> - prev = vma;
> + mas_set(&mas, start);
> + prev = mas_prev(&mas, 0);
> + if (prev != vma)
> + mas_next(&mas, ULONG_MAX);

Same here: this mas_prev()/mas_next() dance needs a comment as well.

>
> ret = 0;
> do {
> @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
> skip:
> prev = vma;
> start = vma->vm_end;
> - vma = vma->vm_next;
> - } while (vma && vma->vm_start < end);
> + vma = mas_next(&mas, end - 1);
> + } while (vma);
> out_unlock:
> mmap_write_unlock(mm);
> mmput(mm);
> diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
> index 33cea484d1ad..e0b2ec2c20f2 100644
> --- a/include/linux/userfaultfd_k.h
> +++ b/include/linux/userfaultfd_k.h
> @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma,
> unsigned long start,
> unsigned long end);
>
> -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> - unsigned long start, unsigned long end,
> - struct list_head *uf);
> +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
> + unsigned long end, struct list_head *uf);
> extern void userfaultfd_unmap_complete(struct mm_struct *mm,
> struct list_head *uf);
>
> @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma,
> return true;
> }
>
> -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> +static inline int userfaultfd_unmap_prep(struct mm_struct *mm,
> unsigned long start, unsigned long end,
> struct list_head *uf)
> {
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 79b8494d83c6..dde74e0b195d 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
> * split, despite we could. This is unlikely enough
> * failure that it's not worth optimizing it for.
> */
> - int error = userfaultfd_unmap_prep(vma, start, end, uf);
> + int error = userfaultfd_unmap_prep(mm, start, end, uf);
>
> if (error)
> return error;
> @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
> goto munmap_full_vma;
> }
>
> - vma_init(&unmap, mm);
> - unmap.vm_start = newbrk;
> - unmap.vm_end = oldbrk;
> - ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf);
> + ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf);
> if (ret)
> return ret;
> ret = 1;
> @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
> }
>
> vma->vm_end = newbrk;
> + vma_init(&unmap, mm);
> + unmap.vm_start = newbrk;
> + unmap.vm_end = oldbrk;
> if (vma_mas_remove(&unmap, mas))
> goto mas_store_fail;
>
> @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
> }
>
> unmap_pages = vma_pages(&unmap);
> - if (unmap.vm_flags & VM_LOCKED) {
> + if (vma->vm_flags & VM_LOCKED) {

Hmm, is this an unrelated bug fix? Even before this patch, unmap never had any
vm_flags set, right?

> mm->locked_vm -= unmap_pages;
> munlock_vma_pages_range(&unmap, newbrk, oldbrk);
> }