[PATCH v2 1/3] iommufd/viommu: Allow associating a KVM VM fd with a vIOMMU

From: Aneesh Kumar K.V (Arm)

Date: Mon Mar 09 2026 - 07:18:36 EST


Add an optional KVM association to IOMMU_VIOMMU_ALLOC by introducing the
IOMMU_VIOMMU_KVM_FD flag and the iommu_viommu_alloc::kvm_vm_fd field.

When the flag is set, iommufd validates that kvm_vm_fd refers to a KVM
VM file and stores a referenced struct file in the vIOMMU object, so
that later iommufd operations can safely resolve the owning VM. The
file reference is dropped when the vIOMMU object is destroyed.

This is preparatory plumbing for subsequent patches that bind TDI state
to the associated KVM VM.

The patch also switches file_is_kvm() from EXPORT_SYMBOL_FOR_KVM_INTERNAL
to EXPORT_SYMBOL_GPL so that the iommufd module can use it.

Cc: Kevin Tian <kevin.tian@xxxxxxxxx>
Cc: Joerg Roedel <joro@xxxxxxxxxx>
Cc: Will Deacon <will@xxxxxxxxxx>
Cc: Bjorn Helgaas <helgaas@xxxxxxxxxx>
Cc: Jonathan Cameron <Jonathan.Cameron@xxxxxxxxxx>
Cc: Dan Williams <dan.j.williams@xxxxxxxxx>
Cc: Alexey Kardashevskiy <aik@xxxxxxx>
Cc: Samuel Ortiz <sameo@xxxxxxxxxxxx>
Cc: Xu Yilun <yilun.xu@xxxxxxxxxxxxxxx>
Cc: Jason Gunthorpe <jgg@xxxxxxxx>
Cc: Suzuki K Poulose <Suzuki.Poulose@xxxxxxx>
Cc: Steven Price <steven.price@xxxxxxx>
Signed-off-by: Aneesh Kumar K.V (Arm) <aneesh.kumar@xxxxxxxxxx>
---
drivers/iommu/iommufd/viommu.c | 54 +++++++++++++++++++++++++++++++++-
include/linux/iommufd.h | 3 ++
include/uapi/linux/iommufd.h | 13 +++++++-
virt/kvm/kvm_main.c | 2 +-
4 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
index 4081deda9b33..08f8930c86da 100644
--- a/drivers/iommu/iommufd/viommu.c
+++ b/drivers/iommu/iommufd/viommu.c
@@ -2,6 +2,45 @@
/* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
*/
#include "iommufd_private.h"
+#include <linux/cleanup.h>
+
+/* Built-in iommufd cannot call file_is_kvm() exported by a modular KVM */
+#if IS_REACHABLE(CONFIG_KVM)
+#include <linux/kvm_host.h>
+
+static int viommu_get_kvm(struct iommufd_viommu *viommu, int kvm_vm_fd)
+{
+	/* fget() yields NULL for a bad fd, and file_is_kvm(NULL) is false */
+	struct file *filp __free(fput) = fget(kvm_vm_fd);
+
+	if (!file_is_kvm(filp))
+		return -EBADF;
+
+	/* The file reference pins the VM; it is dropped by viommu_put_kvm() */
+	viommu->kvm_filp = no_free_ptr(filp);
+	return 0;
+}
+
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+	if (!viommu->kvm_filp)
+		return;
+
+	fput(viommu->kvm_filp);
+	viommu->kvm_filp = NULL;
+}
+
+#else
+static inline int viommu_get_kvm(struct iommufd_viommu *viommu, int kvm_vm_fd)
+{
+	return -EOPNOTSUPP;
+}
+
+static inline void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+}
+
+#endif

void iommufd_viommu_destroy(struct iommufd_object *obj)
{
@@ -12,6 +51,8 @@ void iommufd_viommu_destroy(struct iommufd_object *obj)
viommu->ops->destroy(viommu);
refcount_dec(&viommu->hwpt->common.obj.users);
xa_destroy(&viommu->vdevs);
+
+ viommu_put_kvm(viommu);
}

int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
@@ -29,7 +70,9 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
size_t viommu_size;
int rc;

- if (cmd->flags || cmd->type == IOMMU_VIOMMU_TYPE_DEFAULT)
+ if (cmd->flags & ~IOMMU_VIOMMU_KVM_FD)
+ return -EOPNOTSUPP;
+ if (cmd->type == IOMMU_VIOMMU_TYPE_DEFAULT)
return -EOPNOTSUPP;

idev = iommufd_get_device(ucmd, cmd->dev_id);
@@ -100,8 +143,17 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
goto out_put_hwpt;
}

+ /* get the kvm details if specified. */
+ if (cmd->flags & IOMMU_VIOMMU_KVM_FD) {
+ rc = viommu_get_kvm(viommu, cmd->kvm_vm_fd);
+ if (rc)
+ goto out_put_hwpt;
+ }
+
cmd->out_viommu_id = viommu->obj.id;
rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
+ if (rc)
+ viommu_put_kvm(viommu);

out_put_hwpt:
iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
diff --git a/include/linux/iommufd.h b/include/linux/iommufd.h
index 6e7efe83bc5d..7c515d3c52db 100644
--- a/include/linux/iommufd.h
+++ b/include/linux/iommufd.h
@@ -12,6 +12,7 @@
#include <linux/refcount.h>
#include <linux/types.h>
#include <linux/xarray.h>
+#include <linux/file.h>
#include <uapi/linux/iommufd.h>

struct device;
@@ -58,6 +59,7 @@ struct iommufd_object {
unsigned int id;
};

+struct kvm;
struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
struct device *dev, u32 *id);
void iommufd_device_unbind(struct iommufd_device *idev);
@@ -101,6 +103,7 @@ struct iommufd_viommu {
struct iommufd_ctx *ictx;
struct iommu_device *iommu_dev;
struct iommufd_hwpt_paging *hwpt;
+ struct file *kvm_filp;

const struct iommufd_viommu_ops *ops;

diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index 1dafbc552d37..b862c3e57133 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -1071,10 +1071,19 @@ struct iommu_viommu_tegra241_cmdqv {
__aligned_u64 out_vintf_mmap_length;
};

+/**
+ * define IOMMU_VIOMMU_KVM_FD - Flag indicating a valid KVM VM file descriptor
+ *
+ * Set this flag when allocating a viommu instance that should be associated
+ * with a specific KVM VM. If this flag is not provided, the kvm_vm_fd field
+ * of &struct iommu_viommu_alloc is ignored.
+ */
+#define IOMMU_VIOMMU_KVM_FD (1U << 0)
+
/**
* struct iommu_viommu_alloc - ioctl(IOMMU_VIOMMU_ALLOC)
* @size: sizeof(struct iommu_viommu_alloc)
- * @flags: Must be 0
+ * @flags: Supported flags (IOMMU_VIOMMU_KVM_FD)
* @type: Type of the virtual IOMMU. Must be defined in enum iommu_viommu_type
* @dev_id: The device's physical IOMMU will be used to back the virtual IOMMU
* @hwpt_id: ID of a nesting parent HWPT to associate to
@@ -1082,6 +1091,7 @@ struct iommu_viommu_tegra241_cmdqv {
* @data_len: Length of the type specific data
* @__reserved: Must be 0
* @data_uptr: User pointer to a driver-specific virtual IOMMU data
+ * @kvm_vm_fd: KVM VM file descriptor when IOMMU_VIOMMU_KVM_FD is set
*
* Allocate a virtual IOMMU object, representing the underlying physical IOMMU's
* virtualization support that is a security-isolated slice of the real IOMMU HW
@@ -1105,6 +1115,7 @@ struct iommu_viommu_alloc {
__u32 data_len;
__u32 __reserved;
__aligned_u64 data_uptr;
+ __s32 kvm_vm_fd;
};
#define IOMMU_VIOMMU_ALLOC _IO(IOMMUFD_TYPE, IOMMUFD_CMD_VIOMMU_ALLOC)

diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 1bc1da66b4b0..f076c5a7a290 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -5481,7 +5481,7 @@ bool file_is_kvm(struct file *file)
{
return file && file->f_op == &kvm_vm_fops;
}
-EXPORT_SYMBOL_FOR_KVM_INTERNAL(file_is_kvm);
+EXPORT_SYMBOL_GPL(file_is_kvm);

static int kvm_dev_ioctl_create_vm(unsigned long type)
{
--
2.43.0