[PATCH 2/3] mshv: Use hmm_range_fault_unlockable() for userfaultfd support

From: Stanislav Kinsburskii

Date: Thu Apr 30 2026 - 21:21:08 EST


Convert the mshv driver's HMM fault path to use
hmm_range_fault_unlockable() instead of hmm_range_fault(). This enables
userfaultfd-backed guest memory regions by allowing the mmap lock to be
dropped during page fault handling.

Extract the per-VMA walk into a dedicated mshv_region_hmm_fault_walk()
helper. The outer mshv_region_hmm_fault_and_lock() handles the do/while
restart loop: if the mmap lock is dropped during a fault (userfaultfd
resolution or similar) or an invalidation occurs (-EBUSY), the function
restarts the entire walk from the beginning with a fresh notifier_seq,
since the VMA layout may have changed.

Signed-off-by: Stanislav Kinsburskii <skinsburskii@xxxxxxxxxxxxxxxxxxx>
---
drivers/hv/mshv_regions.c | 127 +++++++++++++++++++++++++++++++--------------
1 file changed, 87 insertions(+), 40 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index d09940e88298e..05665446ca6d9 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -565,6 +565,75 @@ int mshv_region_get(struct mshv_region *region)
return kref_get_unless_zero(&region->mreg_refcount);
}

+/**
+ * mshv_region_hmm_fault_walk - Walk VMAs and fault in pages for a range
+ * @region:   Pointer to the memory region structure
+ * @range:    HMM range structure (caller sets notifier and notifier_seq)
+ * @start:    Starting virtual address of the range to fault (inclusive)
+ * @end:      Ending virtual address of the range to fault (exclusive)
+ * @pfns:     Output array for page frame numbers with HMM flags
+ * @locked:   Pointer to lock state; set to 0 if the mmap lock was dropped
+ * @do_fault: If true, fault in missing pages; if false, snapshot only
+ *
+ * Iterates through VMAs covering [start, end), collecting page frame numbers
+ * via hmm_range_fault_unlockable() for each VMA segment. When @do_fault is
+ * true, missing pages are faulted in and write faults are requested only
+ * when both the VMA and the hypervisor mapping permit writes, to avoid
+ * breaking copy-on-write semantics on read-only mappings.
+ *
+ * Context: Call with the mmap read lock held and *@locked == 1. If the lock
+ * is dropped, the walk stops at once and must be restarted by the caller.
+ *
+ * Return: 0 or a negative errno; if *@locked was cleared, @pfns is partial.
+ */
+static int mshv_region_hmm_fault_walk(struct mshv_region *region,
+				      struct hmm_range *range,
+				      unsigned long start, unsigned long end,
+				      unsigned long *pfns, int *locked,
+				      bool do_fault)
+{
+	unsigned long cur_start = start;
+	unsigned long *cur_pfns = pfns;
+	int ret;
+
+	while (cur_start < end) {
+		struct vm_area_struct *vma;
+
+		vma = vma_lookup(range->notifier->mm, cur_start);
+		if (!vma)
+			return -EFAULT;
+
+		range->hmm_pfns = cur_pfns;
+		range->start = cur_start;
+		range->end = min(vma->vm_end, end);
+		range->default_flags = 0;
+		if (do_fault) {
+			range->default_flags = HMM_PFN_REQ_FAULT;
+			/*
+			 * Only request writable pages when both the VMA and
+			 * the hypervisor mapping allow writes. Otherwise HMM
+			 * would trigger COW on read-only mappings (e.g. shared
+			 * zero pages, file-backed pages), breaking their
+			 * copy-on-write semantics and potentially granting the
+			 * guest write access to shared host pages.
+			 */
+			if ((vma->vm_flags & VM_WRITE) &&
+			    (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
+				range->default_flags |= HMM_PFN_REQ_WRITE;
+		}
+
+		ret = hmm_range_fault_unlockable(range, locked);
+		/* Stop on error; after a lock drop @vma may be stale. */
+		if (ret || !*locked)
+			return ret;
+
+		cur_start = range->end;
+		cur_pfns += (range->end - range->start) >> PAGE_SHIFT;
+	}
+
+	return 0;
+}
+
/**
* mshv_region_hmm_fault_and_lock - Fault in pages across VMAs and lock
* the memory region
@@ -575,11 +644,9 @@ int mshv_region_get(struct mshv_region *region)
* @do_fault: If true, fault in missing pages; if false, snapshot only
* pages already present in page tables
*
- * Iterates through VMAs covering [start, end), collecting page frame
- * numbers via hmm_range_fault() for each VMA segment. When @do_fault
- * is true, missing pages are faulted in and write faults are requested
- * only when both the VMA and the hypervisor mapping permit writes, to
- * avoid breaking copy-on-write semantics on read-only mappings.
+ * Faults in pages covering [start, end) and acquires region->mreg_mutex.
+ * If the mmap lock is dropped during the fault (e.g. by userfaultfd) or
+ * the mmu notifier sequence is invalidated, the entire walk is restarted.
*
* On success, returns with region->mreg_mutex held; the caller is
* responsible for releasing it. Returns -EBUSY if the mmu notifier
@@ -597,47 +664,27 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_region *region,
.notifier = &region->mreg_mni,
};
struct mm_struct *mm = region->mreg_mni.mm;
+ int locked;
int ret;

- range.notifier_seq = mmu_interval_read_begin(range.notifier);
- mmap_read_lock(mm);
- while (start < end) {
- struct vm_area_struct *vma;
+ do {
+ range.notifier_seq = mmu_interval_read_begin(range.notifier);
+ locked = 1;
+ mmap_read_lock(mm);

- vma = vma_lookup(mm, start);
- if (!vma) {
- ret = -EFAULT;
- break;
- }
+ ret = mshv_region_hmm_fault_walk(region, &range, start, end,
+ pfns, &locked, do_fault);

- range.hmm_pfns = pfns;
- range.start = start;
- range.end = min(vma->vm_end, end);
- range.default_flags = 0;
- if (do_fault) {
- range.default_flags = HMM_PFN_REQ_FAULT;
- /*
- * Only request writable pages from HMM when both
- * the VMA and the hypervisor mapping allow writes.
- * Without this, hmm_range_fault() would trigger
- * COW on read-only mappings (e.g. shared zero
- * pages, file-backed pages), breaking
- * copy-on-write semantics and potentially granting
- * the guest write access to shared host pages.
- */
- if ((vma->vm_flags & VM_WRITE) &&
- (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
- range.default_flags |= HMM_PFN_REQ_WRITE;
- }
+ if (locked)
+ mmap_read_unlock(mm);

- ret = hmm_range_fault(&range);
- if (ret)
- break;
+ /*
+ * If the lock was dropped (by userfaultfd or similar), restart
+ * the entire walk with a fresh notifier_seq since the VMA layout
+ * may have changed. Also restart on -EBUSY (invalidation).
+ */
+ } while (!locked || ret == -EBUSY);

- start = range.end;
- pfns += (range.end - range.start) >> PAGE_SHIFT;
- }
- mmap_read_unlock(mm);
if (ret)
return ret;