Re: [PATCH] vfio/pci: make the vfio_pci_mmap_fault reentrant

From: Alex Williamson
Date: Tue Mar 09 2021 - 14:57:48 EST


On Tue, 9 Mar 2021 12:26:07 -0700
Alex Williamson <alex.williamson@xxxxxxxxxx> wrote:

> On Tue, 9 Mar 2021 13:47:39 -0500
> Peter Xu <peterx@xxxxxxxxxx> wrote:
>
> > On Tue, Mar 09, 2021 at 12:40:04PM -0400, Jason Gunthorpe wrote:
> > > On Tue, Mar 09, 2021 at 08:29:51AM -0700, Alex Williamson wrote:
> > > > On Tue, 9 Mar 2021 08:46:09 -0400
> > > > Jason Gunthorpe <jgg@xxxxxxxxxx> wrote:
> > > >
> > > > > On Tue, Mar 09, 2021 at 03:49:09AM +0000, Zengtao (B) wrote:
> > > > > > Hi guys:
> > > > > >
> > > > > > Thanks for the helpful comments, after rethinking the issue, I have proposed
> > > > > > the following change:
> > > > > > 1. follow_pte instead of follow_pfn.
> > > > >
> > > > > Still no on follow_pfn, you don't need it once you use vmf_insert_pfn
> > > >
> > > > vmf_insert_pfn() only solves the BUG_ON, follow_pte() is being used
> > > > here to determine whether the translation is already present to avoid
> > > > both duplicate work in inserting the translation and allocating a
> > > > duplicate vma tracking structure.
> > >
> > > Oh.. Doing something stateful in fault is not nice at all
> > >
> > > I would rather see __vfio_pci_add_vma() search the vma_list for dups
> > > than call follow_pfn/pte..
> >
> > It seems to me that searching vma list is still the simplest way to fix the
> > problem for the current code base. I see io_remap_pfn_range() is also used in
> > the new series - maybe that'll need to be moved to where PCI_COMMAND_MEMORY got
> > turned on/off in the new series (I just noticed remap_pfn_range modifies vma
> > flags..), as you suggested in the other email.
>
>
> In the new series, I think the fault handler becomes (untested):
>
> static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
> { /* Untested sketch: inserts only the single page at the faulting address. */
> struct vm_area_struct *vma = vmf->vma;
> struct vfio_pci_device *vdev = vma->vm_private_data;
> unsigned long base_pfn, pgoff;
> vm_fault_t ret = VM_FAULT_SIGBUS; /* returned unless the insert below runs */
>
> if (vfio_pci_bar_vma_to_pfn(vma, &base_pfn)) /* presumably resolves the BAR's base pfn; nonzero treated as failure (TODO confirm) */
> return ret;
>
> pgoff = (vmf->address - vma->vm_start) >> PAGE_SHIFT; /* page index of the fault within the vma */
>
> down_read(&vdev->memory_lock); /* held across both the enabled check and the insert */
>
> if (__vfio_pci_memory_enabled(vdev)) /* otherwise ret stays VM_FAULT_SIGBUS */
> ret = vmf_insert_pfn(vma, vmf->address, pgoff + base_pfn);
>
> up_read(&vdev->memory_lock);
>
> return ret;
> }

And I think this is what we end up with for the current code base:

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 65e7e6b44578..2f247ab18c66 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -1568,19 +1568,24 @@ void vfio_pci_memory_unlock_and_restore(struct vfio_pci_device *vdev, u16 cmd)
}

-/* Caller holds vma_lock */
-static int __vfio_pci_add_vma(struct vfio_pci_device *vdev,
- struct vm_area_struct *vma)
+/* Caller holds vma_lock; returns new entry or ERR_PTR(-EEXIST/-ENOMEM) */
+static struct vfio_pci_mmap_vma *__vfio_pci_add_vma(struct vfio_pci_device *vdev,
+ struct vm_area_struct *vma)
{
struct vfio_pci_mmap_vma *mmap_vma;

+ list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
+ if (mmap_vma->vma == vma)
+ return ERR_PTR(-EEXIST);
+ }
+
mmap_vma = kmalloc(sizeof(*mmap_vma), GFP_KERNEL);
if (!mmap_vma)
- return -ENOMEM;
+ return ERR_PTR(-ENOMEM);

mmap_vma->vma = vma;
list_add(&mmap_vma->vma_next, &vdev->vma_list);

- return 0;
+ return mmap_vma;
}

/*
@@ -1612,30 +1617,47 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
{
struct vm_area_struct *vma = vmf->vma;
struct vfio_pci_device *vdev = vma->vm_private_data;
- vm_fault_t ret = VM_FAULT_NOPAGE;
+ struct vfio_pci_mmap_vma *mmap_vma;
+ unsigned long vaddr, pfn;
+ vm_fault_t ret;

mutex_lock(&vdev->vma_lock);
down_read(&vdev->memory_lock);

if (!__vfio_pci_memory_enabled(vdev)) {
ret = VM_FAULT_SIGBUS;
- mutex_unlock(&vdev->vma_lock);
goto up_out;
}

- if (__vfio_pci_add_vma(vdev, vma)) {
- ret = VM_FAULT_OOM;
- mutex_unlock(&vdev->vma_lock);
+ mmap_vma = __vfio_pci_add_vma(vdev, vma);
+ if (IS_ERR(mmap_vma)) {
+ /* A concurrent fault might have already populated this vma */
+ ret = (PTR_ERR(mmap_vma) == -EEXIST) ? VM_FAULT_NOPAGE :
+ VM_FAULT_OOM;
goto up_out;
}

- mutex_unlock(&vdev->vma_lock);
-
- if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
- vma->vm_end - vma->vm_start, vma->vm_page_prot))
- ret = VM_FAULT_SIGBUS;
-
+ /*
+ * Populate the whole vma with vma_lock still held, so a concurrent
+ * fault on this vma finds the tracking entry (-EEXIST above) rather
+ * than inserting again. io_remap_pfn_range() applied pgprot_decrypted();
+ * keep doing so, or MMIO gets mapped with the memory encryption bit
+ * set on SME/SEV hosts. On failure, zap the partial mapping and untrack.
+ */
+ for (vaddr = vma->vm_start, pfn = vma->vm_pgoff;
+ vaddr < vma->vm_end; vaddr += PAGE_SIZE, pfn++) {
+ ret = vmf_insert_pfn_prot(vma, vaddr, pfn,
+ pgprot_decrypted(vma->vm_page_prot));
+ if (ret != VM_FAULT_NOPAGE) {
+ zap_vma_ptes(vma, vma->vm_start,
+ vma->vm_end - vma->vm_start);
+ list_del(&mmap_vma->vma_next);
+ kfree(mmap_vma);
+ break;
+ }
+ }
up_out:
up_read(&vdev->memory_lock);
+ mutex_unlock(&vdev->vma_lock);
return ret;
}