[PATCH v3 5/8] iommufd: Associate fault object with iommufd_hw_pagetable

From: Lu Baolu
Date: Mon Jan 22 2024 - 02:46:29 EST


When allocating a user iommufd_hw_pagetable, user space is allowed to
associate a fault object with the hw_pagetable by specifying the fault
object ID in the page table allocation data and setting the
IOMMU_HWPT_FAULT_ID_VALID flag bit.

On a successful return of hwpt allocation, the user can retrieve and
respond to page faults by reading from and writing to the file interface
of the fault object.

Once a fault object has been associated with a hwpt, the hwpt is
iopf-capable, as indicated by fault_capable being set to true. Attaching,
detaching, or replacing an iopf-capable hwpt on a RID or PASID differs
from the same operations on a hwpt that is not iopf-capable. The
implementation of these operations will be introduced in the next patch.

Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx>
---
drivers/iommu/iommufd/iommufd_private.h | 11 ++++++++
include/uapi/linux/iommufd.h | 6 +++++
drivers/iommu/iommufd/fault.c | 14 ++++++++++
drivers/iommu/iommufd/hw_pagetable.c | 36 +++++++++++++++++++------
4 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 52d83e888bd0..2780bed0c6b1 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -293,6 +293,8 @@ int iommufd_check_iova_range(struct io_pagetable *iopt,
struct iommufd_hw_pagetable {
struct iommufd_object obj;
struct iommu_domain *domain;
+ struct iommufd_fault *fault;
+ bool fault_capable : 1;
};

struct iommufd_hwpt_paging {
@@ -446,8 +448,17 @@ struct iommufd_fault {
struct wait_queue_head wait_queue;
};

+static inline struct iommufd_fault *
+iommufd_get_fault(struct iommufd_ucmd *ucmd, u32 id)
+{
+ return container_of(iommufd_get_object(ucmd->ictx, id,
+ IOMMUFD_OBJ_FAULT),
+ struct iommufd_fault, obj);
+}
+
int iommufd_fault_alloc(struct iommufd_ucmd *ucmd);
void iommufd_fault_destroy(struct iommufd_object *obj);
+int iommufd_fault_iopf_handler(struct iopf_group *group);

#ifdef CONFIG_IOMMUFD_TEST
int iommufd_test(struct iommufd_ucmd *ucmd);
diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index c32d62b02306..7481cdd57027 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -357,10 +357,13 @@ struct iommu_vfio_ioas {
* the parent HWPT in a nesting configuration.
* @IOMMU_HWPT_ALLOC_DIRTY_TRACKING: Dirty tracking support for device IOMMU is
* enforced on device attachment
+ * @IOMMU_HWPT_FAULT_ID_VALID: The fault_id field of hwpt allocation data is
+ * valid.
*/
enum iommufd_hwpt_alloc_flags {
IOMMU_HWPT_ALLOC_NEST_PARENT = 1 << 0,
IOMMU_HWPT_ALLOC_DIRTY_TRACKING = 1 << 1,
+ IOMMU_HWPT_FAULT_ID_VALID = 1 << 2,
};

/**
@@ -411,6 +414,8 @@ enum iommu_hwpt_data_type {
* @__reserved: Must be 0
* @data_type: One of enum iommu_hwpt_data_type
* @data_len: Length of the type specific data
+ * @fault_id: The ID of the IOMMUFD_FAULT object. Valid only if the
+ * IOMMU_HWPT_FAULT_ID_VALID bit is set in the flags field.
* @data_uptr: User pointer to the type specific data
*
* Explicitly allocate a hardware page table object. This is the same object
@@ -441,6 +446,7 @@ struct iommu_hwpt_alloc {
__u32 __reserved;
__u32 data_type;
__u32 data_len;
+ __u32 fault_id;
__aligned_u64 data_uptr;
};
#define IOMMU_HWPT_ALLOC _IO(IOMMUFD_TYPE, IOMMUFD_CMD_HWPT_ALLOC)
diff --git a/drivers/iommu/iommufd/fault.c b/drivers/iommu/iommufd/fault.c
index 9844a85feeb2..e752d1c49dde 100644
--- a/drivers/iommu/iommufd/fault.c
+++ b/drivers/iommu/iommufd/fault.c
@@ -253,3 +253,17 @@ int iommufd_fault_alloc(struct iommufd_ucmd *ucmd)

return rc;
}
+
+int iommufd_fault_iopf_handler(struct iopf_group *group)
+{
+ struct iommufd_hw_pagetable *hwpt = group->cookie->domain->fault_data;
+ struct iommufd_fault *fault = hwpt->fault;
+
+ mutex_lock(&fault->mutex);
+ list_add_tail(&group->node, &fault->deliver);
+ mutex_unlock(&fault->mutex);
+
+ wake_up_interruptible(&fault->wait_queue);
+
+ return 0;
+}
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index 3f3f1fa1a0a9..2703d5aea4f5 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -8,6 +8,15 @@
#include "../iommu-priv.h"
#include "iommufd_private.h"

+static void __iommufd_hwpt_destroy(struct iommufd_hw_pagetable *hwpt)
+{
+ if (hwpt->domain)
+ iommu_domain_free(hwpt->domain);
+
+ if (hwpt->fault)
+ iommufd_put_object(hwpt->fault->ictx, &hwpt->fault->obj);
+}
+
void iommufd_hwpt_paging_destroy(struct iommufd_object *obj)
{
struct iommufd_hwpt_paging *hwpt_paging =
@@ -22,9 +31,7 @@ void iommufd_hwpt_paging_destroy(struct iommufd_object *obj)
hwpt_paging->common.domain);
}

- if (hwpt_paging->common.domain)
- iommu_domain_free(hwpt_paging->common.domain);
-
+ __iommufd_hwpt_destroy(&hwpt_paging->common);
refcount_dec(&hwpt_paging->ioas->obj.users);
}

@@ -49,9 +56,7 @@ void iommufd_hwpt_nested_destroy(struct iommufd_object *obj)
struct iommufd_hwpt_nested *hwpt_nested =
container_of(obj, struct iommufd_hwpt_nested, common.obj);

- if (hwpt_nested->common.domain)
- iommu_domain_free(hwpt_nested->common.domain);
-
+ __iommufd_hwpt_destroy(&hwpt_nested->common);
refcount_dec(&hwpt_nested->parent->common.obj.users);
}

@@ -213,7 +218,8 @@ iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
struct iommufd_hw_pagetable *hwpt;
int rc;

- if (flags || !user_data->len || !ops->domain_alloc_user)
+ if ((flags & ~IOMMU_HWPT_FAULT_ID_VALID) ||
+ !user_data->len || !ops->domain_alloc_user)
return ERR_PTR(-EOPNOTSUPP);
if (parent->auto_domain || !parent->nest_parent)
return ERR_PTR(-EINVAL);
@@ -227,7 +233,7 @@ iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
refcount_inc(&parent->common.obj.users);
hwpt_nested->parent = parent;

- hwpt->domain = ops->domain_alloc_user(idev->dev, flags,
+ hwpt->domain = ops->domain_alloc_user(idev->dev, 0,
parent->common.domain, user_data);
if (IS_ERR(hwpt->domain)) {
rc = PTR_ERR(hwpt->domain);
@@ -307,6 +313,20 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
goto out_put_pt;
}

+ if (cmd->flags & IOMMU_HWPT_FAULT_ID_VALID) {
+ struct iommufd_fault *fault;
+
+ fault = iommufd_get_fault(ucmd, cmd->fault_id);
+ if (IS_ERR(fault)) {
+ rc = PTR_ERR(fault);
+ goto out_hwpt;
+ }
+ hwpt->fault = fault;
+ hwpt->domain->iopf_handler = iommufd_fault_iopf_handler;
+ hwpt->domain->fault_data = hwpt;
+ hwpt->fault_capable = true;
+ }
+
cmd->out_hwpt_id = hwpt->obj.id;
rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
if (rc)
--
2.34.1