[PATCH v4 22/25] device/dax: Properly refcount device dax pages when mapping

From: Alistair Popple
Date: Tue Dec 17 2024 - 00:25:31 EST


Device DAX pages are currently not reference counted when mapped,
instead relying on the devmap PTE bit to ensure mapping code will not
get/put references. This requires special handling in various page
table walkers, particularly GUP, to manage references on the
underlying pgmap to ensure the pages remain valid.

However, there is no reason these pages can't be refcounted properly at
map time. Doing so eliminates the need for the devmap PTE bit,
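freeing up a precious PTE bit. It also simplifies GUP as it no longer
needs to manage the special pgmap references and can instead just
treat the pages normally as defined by vm_normal_page().

Because device DAX (MEMORY_DEVICE_GENERIC) pages are now released
through free_zone_device_folio(), that function is updated to handle
them: a MEMORY_DEVICE_GENERIC pgmap may have no dev_pagemap_ops, no
page_free() callback is invoked for it, and folio->mapping is left
intact (as is already done for MEMORY_DEVICE_FS_DAX) before the
refcount is reset to 1.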

Signed-off-by: Alistair Popple <apopple@xxxxxxxxxx>
---
drivers/dax/device.c | 15 +++++++++------
mm/memremap.c | 13 ++++++-------
2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/drivers/dax/device.c b/drivers/dax/device.c
index 6d74e62..fd22dbf 100644
--- a/drivers/dax/device.c
+++ b/drivers/dax/device.c
@@ -126,11 +126,12 @@ static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
return VM_FAULT_SIGBUS;
}

- pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+ pfn = phys_to_pfn_t(phys, 0);

dax_set_mapping(vmf, pfn, fault_size);

- return vmf_insert_mixed(vmf->vma, vmf->address, pfn);
+ return vmf_insert_page_mkwrite(vmf, pfn_t_to_page(pfn),
+ vmf->flags & FAULT_FLAG_WRITE);
}

static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
@@ -169,11 +170,12 @@ static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
return VM_FAULT_SIGBUS;
}

- pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+ pfn = phys_to_pfn_t(phys, 0);

dax_set_mapping(vmf, pfn, fault_size);

- return vmf_insert_pfn_pmd(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
+ return vmf_insert_folio_pmd(vmf, page_folio(pfn_t_to_page(pfn)),
+ vmf->flags & FAULT_FLAG_WRITE);
}

#ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
@@ -214,11 +216,12 @@ static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
return VM_FAULT_SIGBUS;
}

- pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
+ pfn = phys_to_pfn_t(phys, 0);

dax_set_mapping(vmf, pfn, fault_size);

- return vmf_insert_pfn_pud(vmf, pfn, vmf->flags & FAULT_FLAG_WRITE);
+ return vmf_insert_folio_pud(vmf, page_folio(pfn_t_to_page(pfn)),
+ vmf->flags & FAULT_FLAG_WRITE);
}
#else
static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
diff --git a/mm/memremap.c b/mm/memremap.c
index 9a8879b..532a52a 100644
--- a/mm/memremap.c
+++ b/mm/memremap.c
@@ -460,11 +460,10 @@ void free_zone_device_folio(struct folio *folio)
{
struct dev_pagemap *pgmap = folio->pgmap;

- if (WARN_ON_ONCE(!pgmap->ops))
- return;
-
- if (WARN_ON_ONCE(pgmap->type != MEMORY_DEVICE_FS_DAX &&
- !pgmap->ops->page_free))
+ if (WARN_ON_ONCE((!pgmap->ops &&
+ pgmap->type != MEMORY_DEVICE_GENERIC) ||
+ (pgmap->ops && !pgmap->ops->page_free &&
+ pgmap->type != MEMORY_DEVICE_FS_DAX)))
return;

mem_cgroup_uncharge(folio);
@@ -494,7 +493,8 @@ void free_zone_device_folio(struct folio *folio)
* zero which indicating the page has been removed from the file
* system mapping.
*/
- if (pgmap->type != MEMORY_DEVICE_FS_DAX)
+ if (pgmap->type != MEMORY_DEVICE_FS_DAX &&
+ pgmap->type != MEMORY_DEVICE_GENERIC)
folio->mapping = NULL;

switch (pgmap->type) {
@@ -509,7 +509,6 @@ void free_zone_device_folio(struct folio *folio)
* Reset the refcount to 1 to prepare for handing out the page
* again.
*/
- pgmap->ops->page_free(folio_page(folio, 0));
folio_set_count(folio, 1);
break;

--
git-series 0.9.1