[PATCH 10/10] iommu/vt-d: Simplify calculate_psi_aligned_address()

From: Lu Baolu

Date: Thu Apr 02 2026 - 03:09:19 EST


From: Jason Gunthorpe <jgg@xxxxxxxxxx>

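This is doing far too much math for the simple task of finding the
smallest naturally aligned power-of-2 region that fully spans the given
range. Apply fls_long() directly to start ^ last: the highest bit in which
the two addresses differ ends their common binary prefix, and its position
is the log2 size of that region. The full-range (0..ULONG_MAX) case no
longer needs special handling in cache_tag_flush_range(), since the helper
now maps it to a full flush itself.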

Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxx>
Link: https://lore.kernel.org/r/4-v1-f175e27af136+11647-iommupt_inv_vtd_jgg@xxxxxxxxxx
Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx>
---
drivers/iommu/intel/cache.c | 49 ++++++++++++-------------------------
1 file changed, 16 insertions(+), 33 deletions(-)

diff --git a/drivers/iommu/intel/cache.c b/drivers/iommu/intel/cache.c
index be8410f0e841..54dd9f7323bd 100644
--- a/drivers/iommu/intel/cache.c
+++ b/drivers/iommu/intel/cache.c
@@ -254,37 +254,25 @@ void cache_tag_unassign_domain(struct dmar_domain *domain,
}

static unsigned long calculate_psi_aligned_address(unsigned long start,
- unsigned long end,
- unsigned long *_mask)
+ unsigned long last,
+ unsigned long *size_order)
{
- unsigned long pages = aligned_nrpages(start, end - start + 1);
- unsigned long aligned_pages = __roundup_pow_of_two(pages);
- unsigned long bitmask = aligned_pages - 1;
- unsigned long mask = ilog2(aligned_pages);
- unsigned long pfn = IOVA_PFN(start);
+ unsigned int sz_lg2;

- /*
- * PSI masks the low order bits of the base address. If the
- * address isn't aligned to the mask, then compute a mask value
- * needed to ensure the target range is flushed.
- */
- if (unlikely(bitmask & pfn)) {
- unsigned long end_pfn = pfn + pages - 1, shared_bits;
-
- /*
- * Since end_pfn <= pfn + bitmask, the only way bits
- * higher than bitmask can differ in pfn and end_pfn is
- * by carrying. This means after masking out bitmask,
- * high bits starting with the first set bit in
- * shared_bits are all equal in both pfn and end_pfn.
- */
- shared_bits = ~(pfn ^ end_pfn) & ~bitmask;
- mask = shared_bits ? __ffs(shared_bits) : MAX_AGAW_PFN_WIDTH;
+ /* log2 of the smallest naturally aligned region spanning start..last */
+ start &= GENMASK(BITS_PER_LONG - 1, VTD_PAGE_SHIFT);
+ sz_lg2 = fls_long(start ^ last);
+ if (sz_lg2 <= VTD_PAGE_SHIFT) {
+ *size_order = 0;
+ return start;
+ }
+ if (unlikely(sz_lg2 >= MAX_AGAW_PFN_WIDTH || sz_lg2 >= BITS_PER_LONG)) {
+ *size_order = MAX_AGAW_PFN_WIDTH;
+ return 0;
}

- *_mask = mask;
-
- return ALIGN_DOWN(start, VTD_PAGE_SIZE << mask);
+ *size_order = sz_lg2 - VTD_PAGE_SHIFT;
+ return start & GENMASK(BITS_PER_LONG - 1, sz_lg2);
}

static void qi_batch_flush_descs(struct intel_iommu *iommu, struct qi_batch *batch)
@@ -441,12 +429,7 @@ void cache_tag_flush_range(struct dmar_domain *domain, unsigned long start,
struct cache_tag *tag;
unsigned long flags;

- if (start == 0 && end == ULONG_MAX) {
- addr = 0;
- mask = MAX_AGAW_PFN_WIDTH;
- } else {
- addr = calculate_psi_aligned_address(start, end, &mask);
- }
+ addr = calculate_psi_aligned_address(start, end, &mask);

spin_lock_irqsave(&domain->cache_lock, flags);
list_for_each_entry(tag, &domain->cache_tags, node) {
--
2.43.0