[PATCH v4 16/49] userfaultfd: Use vma iterator
From: Liam R. Howlett
Date: Fri Jan 20 2023 - 11:53:39 EST
From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>
Use the vma iterator so that the VMA helpers (vmi_vma_merge(),
vmi_split_vma()) can invalidate or update the iterator themselves,
instead of requiring each caller to do so (e.g. via mas_pause()).
Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
---
fs/userfaultfd.c | 87 ++++++++++++++++++------------------------------
1 file changed, 33 insertions(+), 54 deletions(-)
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 15a5bf765d43..4334bd35984d 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -883,7 +883,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
/* len == 0 means wake all */
struct userfaultfd_wake_range range = { .len = 0, };
unsigned long new_flags;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ VMA_ITERATOR(vmi, mm, 0);
WRITE_ONCE(ctx->released, true);
@@ -900,7 +900,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
*/
mmap_write_lock(mm);
prev = NULL;
- mas_for_each(&mas, vma, ULONG_MAX) {
+ for_each_vma(vmi, vma) {
cond_resched();
BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
!!(vma->vm_flags & __VM_UFFD_FLAGS));
@@ -909,13 +909,12 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
continue;
}
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
- prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
+ prev = vmi_vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
new_flags, vma->anon_vma,
vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
- mas_pause(&mas);
vma = prev;
} else {
prev = vma;
@@ -1302,7 +1301,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
bool found;
bool basic_ioctls;
unsigned long start, end, vma_end;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ struct vma_iterator vmi;
user_uffdio_register = (struct uffdio_register __user *) arg;
@@ -1344,17 +1343,13 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
if (!mmget_not_zero(mm))
goto out;
+ ret = -EINVAL;
mmap_write_lock(mm);
- mas_set(&mas, start);
- vma = mas_find(&mas, ULONG_MAX);
+ vma_iter_init(&vmi, mm, start);
+ vma = vma_find(&vmi, end);
if (!vma)
goto out_unlock;
- /* check that there's at least one vma in the range */
- ret = -EINVAL;
- if (vma->vm_start >= end)
- goto out_unlock;
-
/*
* If the first vma contains huge pages, make sure start address
* is aligned to huge page size.
@@ -1371,7 +1366,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
*/
found = false;
basic_ioctls = false;
- for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+ cur = vma;
+ do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1428,16 +1424,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
basic_ioctls = true;
found = true;
- }
+ } for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
- mas_set(&mas, start);
- prev = mas_prev(&mas, 0);
- if (prev != vma)
- mas_next(&mas, ULONG_MAX);
+ vma_iter_set(&vmi, start);
+ prev = vma_prev(&vmi);
ret = 0;
- do {
+ for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vm_flags));
@@ -1458,30 +1452,25 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
vma_end = min(end, vma->vm_end);
new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
- prev = vma_merge(mm, prev, start, vma_end, new_flags,
+ prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
((struct vm_userfaultfd_ctx){ ctx }),
anon_vma_name(vma));
if (prev) {
/* vma_merge() invalidated the mas */
- mas_pause(&mas);
vma = prev;
goto next;
}
if (vma->vm_start < start) {
- ret = split_vma(mm, vma, start, 1);
+ ret = vmi_split_vma(&vmi, mm, vma, start, 1);
if (ret)
break;
- /* split_vma() invalidated the mas */
- mas_pause(&mas);
}
if (vma->vm_end > end) {
- ret = split_vma(mm, vma, end, 0);
+ ret = vmi_split_vma(&vmi, mm, vma, end, 0);
if (ret)
break;
- /* split_vma() invalidated the mas */
- mas_pause(&mas);
}
next:
/*
@@ -1498,8 +1487,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
- vma = mas_next(&mas, end - 1);
- } while (vma);
+ }
+
out_unlock:
mmap_write_unlock(mm);
mmput(mm);
@@ -1543,7 +1532,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
bool found;
unsigned long start, end, vma_end;
const void __user *buf = (void __user *)arg;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ struct vma_iterator vmi;
ret = -EFAULT;
if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1562,14 +1551,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out;
mmap_write_lock(mm);
- mas_set(&mas, start);
- vma = mas_find(&mas, ULONG_MAX);
- if (!vma)
- goto out_unlock;
-
- /* check that there's at least one vma in the range */
ret = -EINVAL;
- if (vma->vm_start >= end)
+ vma_iter_init(&vmi, mm, start);
+ vma = vma_find(&vmi, end);
+ if (!vma)
goto out_unlock;
/*
@@ -1587,8 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
* Search for not compatible vmas.
*/
found = false;
- ret = -EINVAL;
- for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+ cur = vma;
+ do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1605,16 +1590,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out_unlock;
found = true;
- }
+ } for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
- mas_set(&mas, start);
- prev = mas_prev(&mas, 0);
- if (prev != vma)
- mas_next(&mas, ULONG_MAX);
-
+ vma_iter_set(&vmi, start);
+ prev = vma_prev(&vmi);
ret = 0;
- do {
+ for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
@@ -1650,26 +1632,23 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
uffd_wp_range(mm, vma, start, vma_end - start, false);
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
- prev = vma_merge(mm, prev, start, vma_end, new_flags,
+ prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
vma = prev;
- mas_pause(&mas);
goto next;
}
if (vma->vm_start < start) {
- ret = split_vma(mm, vma, start, 1);
+ ret = vmi_split_vma(&vmi, mm, vma, start, 1);
if (ret)
break;
- mas_pause(&mas);
}
if (vma->vm_end > end) {
- ret = split_vma(mm, vma, end, 0);
+ ret = vmi_split_vma(&vmi, mm, vma, end, 0);
if (ret)
break;
- mas_pause(&mas);
}
next:
/*
@@ -1683,8 +1662,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
- vma = mas_next(&mas, end - 1);
- } while (vma);
+ }
+
out_unlock:
mmap_write_unlock(mm);
mmput(mm);
--
2.35.1