Re: [PATCH v2 3/4] arm64: mm: Don't remap pgtables for allocate vs populate

From: Ryan Roberts
Date: Fri Apr 12 2024 - 03:53:51 EST


Hi Mark,

[...]

> Does something like the below look ok to you? The trade-off performance-wise is
> that late uses will still use the fixmap, and will redundantly zero the tables,
> but the logic remains fairly simple, and I suspect the overhead for late
> allocations might not matter since the bulk of late changes are non-allocating.
>
> Mark
>
> ---->8-----
> diff --git a/arch/arm64/include/asm/pgtable.h b/arch/arm64/include/asm/pgtable.h
> index 105a95a8845c5..1eecf87021bd0 100644
> --- a/arch/arm64/include/asm/pgtable.h
> +++ b/arch/arm64/include/asm/pgtable.h
> @@ -1010,6 +1010,8 @@ static inline p4d_t *p4d_offset_kimg(pgd_t *pgdp, u64 addr)
>
> static inline bool pgtable_l5_enabled(void) { return false; }
>
> +#define p4d_index(addr) (((addr) >> P4D_SHIFT) & (PTRS_PER_P4D - 1)
> +
> /* Match p4d_offset folding in <asm/generic/pgtable-nop4d.h> */
> #define p4d_set_fixmap(addr) NULL
> #define p4d_set_fixmap_offset(p4dp, addr) ((p4d_t *)p4dp)
> diff --git a/arch/arm64/mm/mmu.c b/arch/arm64/mm/mmu.c
> index dc86dceb0efe6..4b944ef8f618c 100644
> --- a/arch/arm64/mm/mmu.c
> +++ b/arch/arm64/mm/mmu.c
> @@ -109,28 +109,12 @@ EXPORT_SYMBOL(phys_mem_access_prot);
> static phys_addr_t __init early_pgtable_alloc(int shift)
> {
> phys_addr_t phys;
> - void *ptr;
>
> phys = memblock_phys_alloc_range(PAGE_SIZE, PAGE_SIZE, 0,
> MEMBLOCK_ALLOC_NOLEAKTRACE);
> if (!phys)
> panic("Failed to allocate page table page\n");
>
> - /*
> - * The FIX_{PGD,PUD,PMD} slots may be in active use, but the FIX_PTE
> - * slot will be free, so we can (ab)use the FIX_PTE slot to initialise
> - * any level of table.
> - */
> - ptr = pte_set_fixmap(phys);
> -
> - memset(ptr, 0, PAGE_SIZE);
> -
> - /*
> - * Implicit barriers also ensure the zeroed page is visible to the page
> - * table walker
> - */
> - pte_clear_fixmap();
> -
> return phys;
> }
>
> @@ -172,6 +156,14 @@ bool pgattr_change_is_safe(u64 old, u64 new)
> return ((old ^ new) & ~mask) == 0;
> }
>
> +static void init_clear_pgtable(void *table)
> +{
> + clear_page(table);
> +
> + /* Ensure the zeroing is observed by page table walks. */
> + dsb(ishst);
> +}
> +
> static pte_t *init_pte(pte_t *ptep, unsigned long addr, unsigned long end,
> phys_addr_t phys, pgprot_t prot)
> {
> @@ -216,12 +208,18 @@ static void alloc_init_cont_pte(pmd_t *pmdp, unsigned long addr,
> pmdval |= PMD_TABLE_PXN;
> BUG_ON(!pgtable_alloc);
> pte_phys = pgtable_alloc(PAGE_SHIFT);
> +
> + ptep = pte_set_fixmap(pte_phys);
> + init_clear_pgtable(ptep);
> +
> __pmd_populate(pmdp, pte_phys, pmdval);
> pmd = READ_ONCE(*pmdp);
> + } else {
> + ptep = pte_set_fixmap(pmd_page_paddr(pmd));
> }
> BUG_ON(pmd_bad(pmd));
>
> - ptep = pte_set_fixmap_offset(pmdp, addr);
> + ptep += pte_index(addr);
> do {
> pgprot_t __prot = prot;
>
> @@ -303,12 +301,18 @@ static void alloc_init_cont_pmd(pud_t *pudp, unsigned long addr,
> pudval |= PUD_TABLE_PXN;
> BUG_ON(!pgtable_alloc);
> pmd_phys = pgtable_alloc(PMD_SHIFT);
> +
> + pmdp = pmd_set_fixmap(pmd_phys);
> + init_clear_pgtable(pmdp);
> +
> __pud_populate(pudp, pmd_phys, pudval);
> pud = READ_ONCE(*pudp);
> + } else {
> + pmdp = pmd_set_fixmap(pud_page_paddr(pud));
> }
> BUG_ON(pud_bad(pud));
>
> - pmdp = pmd_set_fixmap_offset(pudp, addr);
> + pmdp += pmd_index(addr);
> do {
> pgprot_t __prot = prot;
>
> @@ -345,12 +349,18 @@ static void alloc_init_pud(p4d_t *p4dp, unsigned long addr, unsigned long end,
> p4dval |= P4D_TABLE_PXN;
> BUG_ON(!pgtable_alloc);
> pud_phys = pgtable_alloc(PUD_SHIFT);
> +
> + pudp = pud_set_fixmap(pud_phys);
> + init_clear_pgtable(pudp);
> +
> __p4d_populate(p4dp, pud_phys, p4dval);
> p4d = READ_ONCE(*p4dp);
> + } else {
> + pudp = pud_set_fixmap(p4d_page_paddr(p4d));

With this change I end up in pgtable folding hell. pXX_set_fixmap() is defined
as NULL when the level is folded (and pXX_page_paddr() is not defined at all).
So it all compiles, but doesn't boot.

I think the simplest approach is to follow this pattern:

----8<----
@@ -340,12 +338,15 @@ static void alloc_init_pud(p4d_t *p4dp, unsigned long addr, unsigned long end,
p4dval |= P4D_TABLE_PXN;
BUG_ON(!pgtable_alloc);
pud_phys = pgtable_alloc(PUD_SHIFT);
+ pudp = pud_set_fixmap(pud_phys);
+ init_clear_pgtable(pudp);
+ pudp += pud_index(addr);
__p4d_populate(p4dp, pud_phys, p4dval);
- p4d = READ_ONCE(*p4dp);
+ } else {
+ BUG_ON(p4d_bad(p4d));
+ pudp = pud_set_fixmap_offset(p4dp, addr);
}
- BUG_ON(p4d_bad(p4d));

- pudp = pud_set_fixmap_offset(p4dp, addr);
do {
pud_t old_pud = READ_ONCE(*pudp);
----8<----

For the map case, we continue to use pud_set_fixmap_offset(), which is defined
for every level and does the right thing when the level is folded.

Note also that the previously unconditional BUG_ON needs to come before the
fixmap call to be useful. It's also really only valuable in the map case: in
the alloc case we are the ones setting the p4d, so we already know it's not
bad. This means we don't need the READ_ONCE() in the alloc case.

Shout if you disagree.

Thanks,
Ryan

> }
> BUG_ON(p4d_bad(p4d));
>
> - pudp = pud_set_fixmap_offset(p4dp, addr);
> + pudp += pud_index(addr);
> do {
> pud_t old_pud = READ_ONCE(*pudp);
>
> @@ -400,12 +410,18 @@ static void alloc_init_p4d(pgd_t *pgdp, unsigned long addr, unsigned long end,
> pgdval |= PGD_TABLE_PXN;
> BUG_ON(!pgtable_alloc);
> p4d_phys = pgtable_alloc(P4D_SHIFT);
> +
> + p4dp = p4d_set_fixmap(p4d_phys);
> + init_clear_pgtable(p4dp);
> +
> __pgd_populate(pgdp, p4d_phys, pgdval);
> pgd = READ_ONCE(*pgdp);
> + } else {
> + p4dp = p4d_set_fixmap(pgd_page_paddr(pgd));
> }
> BUG_ON(pgd_bad(pgd));
>
> - p4dp = p4d_set_fixmap_offset(pgdp, addr);
> + p4dp += p4d_index(addr);
> do {
> p4d_t old_p4d = READ_ONCE(*p4dp);
>
> @@ -475,8 +491,6 @@ static phys_addr_t __pgd_pgtable_alloc(int shift)
> void *ptr = (void *)__get_free_page(GFP_PGTABLE_KERNEL);
> BUG_ON(!ptr);
>
> - /* Ensure the zeroed page is visible to the page table walker */
> - dsb(ishst);
> return __pa(ptr);
> }
>