[RFC PATCH 3/4] mm/memory-tiers: register CXL nodes to socket-aware packages via initiator

From: Rakie Kim

Date: Mon Mar 16 2026 - 01:13:52 EST


CXL memory nodes appear without an explicit socket association.
Relying on plain NUMA distance does not convey which physical package
(CPU socket) they should belong to, which in turn makes locality-aware
placement ambiguous.

This change introduces a registration path that binds a CXL memory node
to a socket-aware "memory package" using an initiator CPU node. The
initiator is the node that represents the host-side attachment of the
region; it is taken from the NUMA node of the region's endpoint memory
devices (the first target that reports a node). By using this nid to
resolve the package, the CXL node is grouped with the CPUs it actually
services.

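The flow is:
- dax/kmem calls mp_probe_package_id() for the device's NUMA node during
  probe, before its memory ranges are processed.
- Each CXL region registers a package notifier. When invoked for the
  region's own target node, it derives an initiator nid from the region's
  endpoint memory devices and registers the CXL node with the package
  layer via mp_add_package_node_by_initiator().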

This provides a deterministic and topology-consistent way to place CXL
nodes into the correct socket grouping, reducing the risk of inadvertent
cross-socket choices that distance alone cannot prevent.

Signed-off-by: Rakie Kim <rakie.kim@xxxxxx>
---
drivers/cxl/core/region.c | 46 +++++++++++++++++++++++++++++++++++++++
drivers/cxl/cxl.h | 1 +
drivers/dax/kmem.c | 2 ++
3 files changed, 49 insertions(+)

diff --git a/drivers/cxl/core/region.c b/drivers/cxl/core/region.c
index 5bd1213737fa..2733e0d465cc 100644
--- a/drivers/cxl/core/region.c
+++ b/drivers/cxl/core/region.c
@@ -2570,6 +2570,91 @@ static int cxl_region_calculate_adistance(struct notifier_block *nb,
return NOTIFY_STOP;
}

+/**
+ * cxl_region_find_nearest_node() - Pick an initiator NUMA node for a region
+ * @cxlr: region whose endpoint targets are scanned
+ *
+ * Walk the region's targets in index order and return the NUMA node of the
+ * memdev behind the first endpoint decoder that reports one. Despite the
+ * name, this is not a distance search: when targets sit on different
+ * nodes, the lowest-indexed target with a node wins.
+ *
+ * Return: a NUMA node id, or NUMA_NO_NODE if no target reports a node.
+ */
+static int cxl_region_find_nearest_node(struct cxl_region *cxlr)
+{
+	struct cxl_region_params *p = &cxlr->params;
+	int i;
+
+	for (i = 0; i < p->nr_targets; i++) {
+		struct cxl_endpoint_decoder *cxled = p->targets[i];
+		int nid;
+
+		/*
+		 * targets[] is read without holding the region lock; tolerate
+		 * a slot that may have been cleared concurrently.
+		 */
+		if (!cxled)
+			continue;
+
+		nid = dev_to_node(&cxled_to_memdev(cxled)->dev);
+		if (nid != NUMA_NO_NODE)
+			return nid;
+	}
+
+	return NUMA_NO_NODE;
+}
+
+/**
+ * cxl_region_add_package_node() - Package notifier callback for a region
+ * @nb: the region's package_notifier
+ * @dax_nid: NUMA node the package notifier chain was invoked for
+ * @data: unused
+ *
+ * If @dax_nid is the target node of this region's HPA start address,
+ * register it with the package layer using the initiator node returned by
+ * cxl_region_find_nearest_node().
+ *
+ * Return: NOTIFY_OK if the node was registered; NOTIFY_DONE if the node does
+ * not belong to this region, no initiator node was found, or registration
+ * failed.
+ */
+static int cxl_region_add_package_node(struct notifier_block *nb,
+				       unsigned long dax_nid, void *data)
+{
+	struct cxl_region *cxlr = container_of(nb, struct cxl_region,
+					       package_notifier);
+	struct resource *res = cxlr->params.res;
+	int region_nid, initiator_nid, rc;
+
+	/* params.res is read unlocked; don't dereference a NULL range. */
+	if (!res)
+		return NOTIFY_DONE;
+
+	/*
+	 * Only claim the node backing this region. Reject NUMA_NO_NODE
+	 * explicitly: comparing it with the unsigned long @dax_nid converts
+	 * it to ULONG_MAX, which would match a caller passing NUMA_NO_NODE.
+	 */
+	region_nid = phys_to_target_node(res->start);
+	if (region_nid == NUMA_NO_NODE || region_nid != dax_nid)
+		return NOTIFY_DONE;
+
+	initiator_nid = cxl_region_find_nearest_node(cxlr);
+	if (initiator_nid == NUMA_NO_NODE)
+		return NOTIFY_DONE;
+
+	rc = mp_add_package_node_by_initiator(dax_nid, initiator_nid);
+	if (rc) {
+		dev_info(&cxlr->dev,
+			 "failed to add package node %lu (initiator node %d): %d\n",
+			 dax_nid, initiator_nid, rc);
+		return NOTIFY_DONE;
+	}
+
+	return NOTIFY_OK;
+}
+
/**
* devm_cxl_add_region - Adds a region to a decoder
* @cxlrd: root decoder
@@ -3788,6 +3829,7 @@ static void shutdown_notifiers(void *_cxlr)

unregister_node_notifier(&cxlr->node_notifier);
unregister_mt_adistance_algorithm(&cxlr->adist_notifier);
+ unregister_mp_package_notifier(&cxlr->package_notifier);
}

static void remove_debugfs(void *dentry)
@@ -3940,6 +3982,10 @@ static int cxl_region_probe(struct device *dev)
cxlr->adist_notifier.priority = 100;
register_mt_adistance_algorithm(&cxlr->adist_notifier);

+ cxlr->package_notifier.notifier_call = cxl_region_add_package_node;
+ cxlr->package_notifier.priority = 100;
+ register_mp_package_notifier(&cxlr->package_notifier);
+
rc = devm_add_action_or_reset(&cxlr->dev, shutdown_notifiers, cxlr);
if (rc)
return rc;
diff --git a/drivers/cxl/cxl.h b/drivers/cxl/cxl.h
index ba17fa86d249..6b6653e31135 100644
--- a/drivers/cxl/cxl.h
+++ b/drivers/cxl/cxl.h
@@ -551,6 +551,7 @@ struct cxl_region {
struct access_coordinate coord[ACCESS_COORDINATE_MAX];
struct notifier_block node_notifier;
struct notifier_block adist_notifier;
+ struct notifier_block package_notifier;
};

struct cxl_nvdimm_bridge {
diff --git a/drivers/dax/kmem.c b/drivers/dax/kmem.c
index c036e4d0b610..32ee66b82cd3 100644
--- a/drivers/dax/kmem.c
+++ b/drivers/dax/kmem.c
@@ -94,6 +94,8 @@ static int dev_dax_kmem_probe(struct dev_dax *dev_dax)
if (IS_ERR(mtype))
return PTR_ERR(mtype);

+ mp_probe_package_id(numa_node);
+
for (i = 0; i < dev_dax->nr_range; i++) {
struct range range;

--
2.34.1