Re: [PATCH 09/11] RISC-V: drivers/iommu/riscv: Add SVA with PASID/ATS/PRI support.

From: Zong Li
Date: Mon Jul 31 2023 - 05:05:49 EST


On Thu, Jul 20, 2023 at 3:35 AM Tomasz Jeznach <tjeznach@xxxxxxxxxxxx> wrote:
>
> Introduces SVA (Shared Virtual Address) for RISC-V IOMMU, with
> ATS/PRI services for capable devices.
>
> Co-developed-by: Sebastien Boeuf <seb@xxxxxxxxxxxx>
> Signed-off-by: Sebastien Boeuf <seb@xxxxxxxxxxxx>
> Signed-off-by: Tomasz Jeznach <tjeznach@xxxxxxxxxxxx>
> ---
> drivers/iommu/riscv/iommu.c | 601 +++++++++++++++++++++++++++++++++++-
> drivers/iommu/riscv/iommu.h | 14 +
> 2 files changed, 610 insertions(+), 5 deletions(-)
>
> diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
> index 2ef6952a2109..6042c35be3ca 100644
> --- a/drivers/iommu/riscv/iommu.c
> +++ b/drivers/iommu/riscv/iommu.c
> @@ -384,6 +384,89 @@ static inline void riscv_iommu_cmd_iodir_set_did(struct riscv_iommu_command *cmd
> FIELD_PREP(RISCV_IOMMU_CMD_IODIR_DID, devid) | RISCV_IOMMU_CMD_IODIR_DV;
> }
>
> +static inline void riscv_iommu_cmd_iodir_set_pid(struct riscv_iommu_command *cmd,
> + unsigned pasid)
> +{
> + cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_IODIR_PID, pasid);
> +}
> +
> +static void riscv_iommu_cmd_ats_inval(struct riscv_iommu_command *cmd)
> +{
> + cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE) |
> + FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_INVAL);
> + cmd->dword1 = 0;
> +}
> +
> +static inline void riscv_iommu_cmd_ats_prgr(struct riscv_iommu_command *cmd)
> +{
> + cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE) |
> + FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_PRGR);
> + cmd->dword1 = 0;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_rid(struct riscv_iommu_command *cmd, u32 rid)
> +{
> + cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_RID, rid);
> +}
> +
> +static void riscv_iommu_cmd_ats_set_pid(struct riscv_iommu_command *cmd, u32 pid)
> +{
> + cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_PID, pid) | RISCV_IOMMU_CMD_ATS_PV;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_dseg(struct riscv_iommu_command *cmd, u8 seg)
> +{
> + cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_DSEG, seg) | RISCV_IOMMU_CMD_ATS_DSV;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_payload(struct riscv_iommu_command *cmd, u64 payload)
> +{
> + cmd->dword1 = payload;
> +}
> +
> +/* Prepare the ATS invalidation payload */
> +static unsigned long riscv_iommu_ats_inval_payload(unsigned long start,
> + unsigned long end, bool global_inv)
> +{
> + size_t len = end - start + 1;
> + unsigned long payload = 0;
> +
> + /*
> + * PCI Express specification
> + * Section 10.2.3.2 Translation Range Size (S) Field
> + */
> + if (len < PAGE_SIZE)
> + len = PAGE_SIZE;
> + else
> + len = __roundup_pow_of_two(len);
> +
> + payload = (start & ~(len - 1)) | (((len - 1) >> 12) << 11);
> +
> + if (global_inv)
> + payload |= RISCV_IOMMU_CMD_ATS_INVAL_G;
> +
> + return payload;
> +}
> +
> +/* Prepare the ATS invalidation payload for all translations to be invalidated. */
> +static unsigned long riscv_iommu_ats_inval_all_payload(bool global_inv)
> +{
> + unsigned long payload = GENMASK_ULL(62, 11);
> +
> + if (global_inv)
> + payload |= RISCV_IOMMU_CMD_ATS_INVAL_G;
> +
> + return payload;
> +}
> +
> +/* Prepare the ATS "Page Request Group Response" payload */
> +static unsigned long riscv_iommu_ats_prgr_payload(u16 dest_id, u8 resp_code, u16 grp_idx)
> +{
> + return FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_DST_ID, dest_id) |
> + FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_RESP_CODE, resp_code) |
> + FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_PRG_INDEX, grp_idx);
> +}
> +
> /* TODO: Convert into lock-less MPSC implementation. */
> static bool riscv_iommu_post_sync(struct riscv_iommu_device *iommu,
> struct riscv_iommu_command *cmd, bool sync)
> @@ -460,6 +543,16 @@ static bool riscv_iommu_iodir_inv_devid(struct riscv_iommu_device *iommu, unsign
> return riscv_iommu_post(iommu, &cmd);
> }
>
> +static bool riscv_iommu_iodir_inv_pasid(struct riscv_iommu_device *iommu,
> + unsigned devid, unsigned pasid)
> +{
> + struct riscv_iommu_command cmd;
> + riscv_iommu_cmd_iodir_inval_pdt(&cmd);
> + riscv_iommu_cmd_iodir_set_did(&cmd, devid);
> + riscv_iommu_cmd_iodir_set_pid(&cmd, pasid);
> + return riscv_iommu_post(iommu, &cmd);
> +}
> +
> static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
> {
> struct riscv_iommu_command cmd;
> @@ -467,6 +560,62 @@ static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
> return riscv_iommu_post_sync(iommu, &cmd, true);
> }
>
> +static void riscv_iommu_mm_invalidate(struct mmu_notifier *mn,
> + struct mm_struct *mm, unsigned long start,
> + unsigned long end)
> +{
> + struct riscv_iommu_command cmd;
> + struct riscv_iommu_endpoint *endpoint;
> + struct riscv_iommu_domain *domain =
> + container_of(mn, struct riscv_iommu_domain, mn);
> + unsigned long iova;
> + /*
> + * The mm_types defines vm_end as the first byte after the end address,
> + * different from IOMMU subsystem using the last address of an address
> + * range. So do a simple translation here by updating what end means.
> + */
> + unsigned long payload = riscv_iommu_ats_inval_payload(start, end - 1, true);
> +
> + riscv_iommu_cmd_inval_vma(&cmd);
> + riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
> + riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
> + if (end > start) {
> + /* Cover only the range that is needed */
> + for (iova = start; iova < end; iova += PAGE_SIZE) {
> + riscv_iommu_cmd_inval_set_addr(&cmd, iova);
> + riscv_iommu_post(domain->iommu, &cmd);
> + }
> + } else {
> + riscv_iommu_post(domain->iommu, &cmd);
> + }
> +
> + riscv_iommu_iofence_sync(domain->iommu);
> +
> + /* ATS invalidation for every device and for specific translation range. */
> + list_for_each_entry(endpoint, &domain->endpoints, domain) {
> + if (!endpoint->pasid_enabled)
> + continue;
> +
> + riscv_iommu_cmd_ats_inval(&cmd);
> + riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
> + riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
> + riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
> + riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> + riscv_iommu_post(domain->iommu, &cmd);
> + }
> + riscv_iommu_iofence_sync(domain->iommu);
> +}
> +
> +static void riscv_iommu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
> +{
> + /* TODO: removed from notifier, cleanup PSCID mapping, flush IOTLB */
> +}
> +
> +static const struct mmu_notifier_ops riscv_iommu_mmuops = {
> + .release = riscv_iommu_mm_release,
> + .invalidate_range = riscv_iommu_mm_invalidate,
> +};
> +
> /* Command queue primary interrupt handler */
> static irqreturn_t riscv_iommu_cmdq_irq_check(int irq, void *data)
> {
> @@ -608,6 +757,128 @@ static void riscv_iommu_add_device(struct riscv_iommu_device *iommu, struct devi
> mutex_unlock(&iommu->eps_mutex);
> }
>
> +/*
> + * Get device reference based on device identifier (requester id).
> + * Decrement reference count with put_device() call.
> + */
> +static struct device *riscv_iommu_get_device(struct riscv_iommu_device *iommu,
> + unsigned devid)
> +{
> + struct rb_node *node;
> + struct riscv_iommu_endpoint *ep;
> + struct device *dev = NULL;
> +
> + mutex_lock(&iommu->eps_mutex);
> +
> + node = iommu->eps.rb_node;
> + while (node && !dev) {
> + ep = rb_entry(node, struct riscv_iommu_endpoint, node);
> + if (ep->devid < devid)
> + node = node->rb_right;
> + else if (ep->devid > devid)
> + node = node->rb_left;
> + else
> + dev = get_device(ep->dev);
> + }
> +
> + mutex_unlock(&iommu->eps_mutex);
> +
> + return dev;
> +}
> +
> +static int riscv_iommu_ats_prgr(struct device *dev, struct iommu_page_response *msg)
> +{
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> + struct riscv_iommu_command cmd;
> + u8 resp_code;
> + unsigned long payload;
> +
> + switch (msg->code) {
> + case IOMMU_PAGE_RESP_SUCCESS:
> + resp_code = 0b0000;
> + break;
> + case IOMMU_PAGE_RESP_INVALID:
> + resp_code = 0b0001;
> + break;
> + case IOMMU_PAGE_RESP_FAILURE:
> + resp_code = 0b1111;
> + break;
> + }
> + payload = riscv_iommu_ats_prgr_payload(ep->devid, resp_code, msg->grpid);
> +
> + /* ATS Page Request Group Response */
> + riscv_iommu_cmd_ats_prgr(&cmd);
> + riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
> + riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
> + if (msg->flags & IOMMU_PAGE_RESP_PASID_VALID)
> + riscv_iommu_cmd_ats_set_pid(&cmd, msg->pasid);
> + riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> + riscv_iommu_post(ep->iommu, &cmd);
> +
> + return 0;
> +}
> +
> +static void riscv_iommu_page_request(struct riscv_iommu_device *iommu,
> + struct riscv_iommu_pq_record *req)
> +{
> + struct iommu_fault_event event = { 0 };
> + struct iommu_fault_page_request *prm = &event.fault.prm;
> + int ret;
> + struct device *dev;
> + unsigned devid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_DID, req->hdr);
> +
> + /* Ignore PGR Stop marker. */
> + if ((req->payload & RISCV_IOMMU_PREQ_PAYLOAD_M) == RISCV_IOMMU_PREQ_PAYLOAD_L)
> + return;
> +
> + dev = riscv_iommu_get_device(iommu, devid);
> + if (!dev) {
> + /* TODO: Handle invalid page request */
> + return;
> + }
> +
> + event.fault.type = IOMMU_FAULT_PAGE_REQ;
> +
> + if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_L)
> + prm->flags |= IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE;
> + if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_W)
> + prm->perm |= IOMMU_FAULT_PERM_WRITE;
> + if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_R)
> + prm->perm |= IOMMU_FAULT_PERM_READ;
> +
> + prm->grpid = FIELD_GET(RISCV_IOMMU_PREQ_PRG_INDEX, req->payload);
> + prm->addr = FIELD_GET(RISCV_IOMMU_PREQ_UADDR, req->payload) << PAGE_SHIFT;
> +
> + if (req->hdr & RISCV_IOMMU_PREQ_HDR_PV) {
> + prm->flags |= IOMMU_FAULT_PAGE_REQUEST_PASID_VALID;
> + /* TODO: where to find this bit */
> + prm->flags |= IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
> + prm->pasid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_PID, req->hdr);
> + }
> +
> + ret = iommu_report_device_fault(dev, &event);
> + if (ret) {
> + struct iommu_page_response resp = {
> + .grpid = prm->grpid,
> + .code = IOMMU_PAGE_RESP_FAILURE,
> + };
> + if (prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID) {
> + resp.flags |= IOMMU_PAGE_RESP_PASID_VALID;
> + resp.pasid = prm->pasid;
> + }
> + riscv_iommu_ats_prgr(dev, &resp);
> + }
> +
> + put_device(dev);
> +}
> +
> +static int riscv_iommu_page_response(struct device *dev,
> + struct iommu_fault_event *evt,
> + struct iommu_page_response *msg)
> +{
> + return riscv_iommu_ats_prgr(dev, msg);
> +}
> +
> /* Page request interface queue primary interrupt handler */
> static irqreturn_t riscv_iommu_priq_irq_check(int irq, void *data)
> {
> @@ -626,7 +897,7 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
> struct riscv_iommu_queue *q = (struct riscv_iommu_queue *)data;
> struct riscv_iommu_device *iommu;
> struct riscv_iommu_pq_record *requests;
> - unsigned cnt, idx, ctrl;
> + unsigned cnt, len, idx, ctrl;
>
> iommu = container_of(q, struct riscv_iommu_device, priq);
> requests = (struct riscv_iommu_pq_record *)q->base;
> @@ -649,7 +920,8 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
> cnt = riscv_iommu_queue_consume(iommu, q, &idx);
> if (!cnt)
> break;
> - dev_warn(iommu->dev, "unexpected %u page requests\n", cnt);
> + for (len = 0; len < cnt; idx++, len++)
> + riscv_iommu_page_request(iommu, &requests[idx]);
> riscv_iommu_queue_release(iommu, q, cnt);
> } while (1);
>
> @@ -660,6 +932,169 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
> * Endpoint management
> */
>
> +/* Endpoint features/capabilities */
> +static void riscv_iommu_disable_ep(struct riscv_iommu_endpoint *ep)
> +{
> + struct pci_dev *pdev;
> +
> + if (!dev_is_pci(ep->dev))
> + return;
> +
> + pdev = to_pci_dev(ep->dev);
> +
> + if (ep->pasid_enabled) {
> + pci_disable_ats(pdev);
> + pci_disable_pri(pdev);
> + pci_disable_pasid(pdev);
> + ep->pasid_enabled = false;
> + }
> +}
> +
> +static void riscv_iommu_enable_ep(struct riscv_iommu_endpoint *ep)
> +{
> + int rc, feat, num;
> + struct pci_dev *pdev;
> + struct device *dev = ep->dev;
> +
> + if (!dev_is_pci(dev))
> + return;
> +
> + if (!ep->iommu->iommu.max_pasids)
> + return;
> +
> + pdev = to_pci_dev(dev);
> +
> + if (!pci_ats_supported(pdev))
> + return;
> +
> + if (!pci_pri_supported(pdev))
> + return;
> +
> + feat = pci_pasid_features(pdev);
> + if (feat < 0)
> + return;
> +
> + num = pci_max_pasids(pdev);
> + if (!num) {
> + dev_warn(dev, "Can't enable PASID (num: %d)\n", num);
> + return;
> + }
> +
> + if (num > ep->iommu->iommu.max_pasids)
> + num = ep->iommu->iommu.max_pasids;
> +
> + rc = pci_enable_pasid(pdev, feat);
> + if (rc) {
> + dev_warn(dev, "Can't enable PASID (rc: %d)\n", rc);
> + return;
> + }
> +
> + rc = pci_reset_pri(pdev);
> + if (rc) {
> + dev_warn(dev, "Can't reset PRI (rc: %d)\n", rc);
> + pci_disable_pasid(pdev);
> + return;
> + }
> +
> + /* TODO: Get supported PRI queue length, hard-code to 32 entries */
> + rc = pci_enable_pri(pdev, 32);
> + if (rc) {
> + dev_warn(dev, "Can't enable PRI (rc: %d)\n", rc);
> + pci_disable_pasid(pdev);
> + return;
> + }
> +
> + rc = pci_enable_ats(pdev, PAGE_SHIFT);
> + if (rc) {
> + dev_warn(dev, "Can't enable ATS (rc: %d)\n", rc);
> + pci_disable_pri(pdev);
> + pci_disable_pasid(pdev);
> + return;
> + }
> +
> + ep->pc = (struct riscv_iommu_pc *)get_zeroed_page(GFP_KERNEL);
> + if (!ep->pc) {
> + pci_disable_ats(pdev);
> + pci_disable_pri(pdev);
> + pci_disable_pasid(pdev);
> + return;
> + }
> +
> + ep->pasid_enabled = true;
> + ep->pasid_feat = feat;
> + ep->pasid_bits = ilog2(num);
> +
> + dev_dbg(ep->dev, "PASID/ATS support enabled, %d bits\n", ep->pasid_bits);
> +}
> +
> +static int riscv_iommu_enable_sva(struct device *dev)
> +{
> + int ret;
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> + if (!ep || !ep->iommu || !ep->iommu->pq_work)
> + return -EINVAL;
> +
> + if (!ep->pasid_enabled)
> + return -ENODEV;
> +
> + ret = iopf_queue_add_device(ep->iommu->pq_work, dev);
> + if (ret)
> + return ret;
> +
> + return iommu_register_device_fault_handler(dev, iommu_queue_iopf, dev);
> +}
> +
> +static int riscv_iommu_disable_sva(struct device *dev)
> +{
> + int ret;
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> + ret = iommu_unregister_device_fault_handler(dev);
> + if (!ret)
> + ret = iopf_queue_remove_device(ep->iommu->pq_work, dev);
> +
> + return ret;
> +}
> +
> +static int riscv_iommu_enable_iopf(struct device *dev)
> +{
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> + if (ep && ep->pasid_enabled)
> + return 0;
> +
> + return -EINVAL;
> +}
> +
> +static int riscv_iommu_dev_enable_feat(struct device *dev, enum iommu_dev_features feat)
> +{
> + switch (feat) {
> + case IOMMU_DEV_FEAT_IOPF:
> + return riscv_iommu_enable_iopf(dev);
> +
> + case IOMMU_DEV_FEAT_SVA:
> + return riscv_iommu_enable_sva(dev);
> +
> + default:
> + return -ENODEV;
> + }
> +}
> +
> +static int riscv_iommu_dev_disable_feat(struct device *dev, enum iommu_dev_features feat)
> +{
> + switch (feat) {
> + case IOMMU_DEV_FEAT_IOPF:
> + return 0;
> +
> + case IOMMU_DEV_FEAT_SVA:
> + return riscv_iommu_disable_sva(dev);
> +
> + default:
> + return -ENODEV;
> + }
> +}
> +
> static int riscv_iommu_of_xlate(struct device *dev, struct of_phandle_args *args)
> {
> return iommu_fwspec_add_ids(dev, args->args, 1);
> @@ -812,6 +1247,7 @@ static struct iommu_device *riscv_iommu_probe_device(struct device *dev)
>
> dev_iommu_priv_set(dev, ep);
> riscv_iommu_add_device(iommu, dev);
> + riscv_iommu_enable_ep(ep);
>
> return &iommu->iommu;
> }
> @@ -843,6 +1279,8 @@ static void riscv_iommu_release_device(struct device *dev)
> riscv_iommu_iodir_inv_devid(iommu, ep->devid);
> }
>
> + riscv_iommu_disable_ep(ep);
> +
> /* Remove endpoint from IOMMU tracking structures */
> mutex_lock(&iommu->eps_mutex);
> rb_erase(&ep->node, &iommu->eps);
> @@ -878,7 +1316,8 @@ static struct iommu_domain *riscv_iommu_domain_alloc(unsigned type)
> type != IOMMU_DOMAIN_DMA_FQ &&
> type != IOMMU_DOMAIN_UNMANAGED &&
> type != IOMMU_DOMAIN_IDENTITY &&
> - type != IOMMU_DOMAIN_BLOCKED)
> + type != IOMMU_DOMAIN_BLOCKED &&
> + type != IOMMU_DOMAIN_SVA)
> return NULL;
>
> domain = kzalloc(sizeof(*domain), GFP_KERNEL);
> @@ -906,6 +1345,9 @@ static void riscv_iommu_domain_free(struct iommu_domain *iommu_domain)
> pr_warn("IOMMU domain is not empty!\n");
> }
>
> + if (domain->mn.ops && iommu_domain->mm)
> + mmu_notifier_unregister(&domain->mn, iommu_domain->mm);
> +
> if (domain->pgtbl.cookie)
> free_io_pgtable_ops(&domain->pgtbl.ops);
>
> @@ -1023,14 +1465,29 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
> */
> val = FIELD_PREP(RISCV_IOMMU_DC_TA_PSCID, domain->pscid);
>
> - dc->ta = cpu_to_le64(val);
> - dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> + if (ep->pasid_enabled) {
> + ep->pc[0].ta = cpu_to_le64(val | RISCV_IOMMU_PC_TA_V);
> + ep->pc[0].fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> + dc->ta = 0;
> + dc->fsc = cpu_to_le64(virt_to_pfn(ep->pc) |
> + FIELD_PREP(RISCV_IOMMU_DC_FSC_MODE, RISCV_IOMMU_DC_FSC_PDTP_MODE_PD8));

Could you explain why we decided to use PD8 directly, rather than PD17 or PD20?

> + } else {
> + dc->ta = cpu_to_le64(val);
> + dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> + }
>
> wmb();
>
> /* Mark device context as valid, synchronise device context cache. */
> val = RISCV_IOMMU_DC_TC_V;
>
> + if (ep->pasid_enabled) {
> + val |= RISCV_IOMMU_DC_TC_EN_ATS |
> + RISCV_IOMMU_DC_TC_EN_PRI |
> + RISCV_IOMMU_DC_TC_DPE |
> + RISCV_IOMMU_DC_TC_PDTV;
> + }
> +
> if (ep->iommu->cap & RISCV_IOMMU_CAP_AMO) {
> val |= RISCV_IOMMU_DC_TC_GADE |
> RISCV_IOMMU_DC_TC_SADE;
> @@ -1051,13 +1508,107 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
> return 0;
> }
>
> +static int riscv_iommu_set_dev_pasid(struct iommu_domain *iommu_domain,
> + struct device *dev, ioasid_t pasid)
> +{
> + struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> + u64 ta, fsc;
> +
> + if (!iommu_domain || !iommu_domain->mm)
> + return -EINVAL;
> +
> + /* Driver uses TC.DPE mode, PASID #0 is incorrect. */
> + if (pasid == 0)
> + return -EINVAL;
> +
> + /* Incorrect domain identifier */
> + if ((int)domain->pscid < 0)
> + return -ENOMEM;
> +
> + /* Process Context table should be set for pasid enabled endpoints. */
> + if (!ep || !ep->pasid_enabled || !ep->dc || !ep->pc)
> + return -ENODEV;
> +
> + domain->pasid = pasid;
> + domain->iommu = ep->iommu;
> + domain->mn.ops = &riscv_iommu_mmuops;
> +
> + /* register mm notifier */
> + if (mmu_notifier_register(&domain->mn, iommu_domain->mm))
> + return -ENODEV;
> +
> + /* TODO: get SXL value for the process, use 32 bit or SATP mode */
> + fsc = virt_to_pfn(iommu_domain->mm->pgd) | satp_mode;
> + ta = RISCV_IOMMU_PC_TA_V | FIELD_PREP(RISCV_IOMMU_PC_TA_PSCID, domain->pscid);
> +
> + fsc = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].fsc), cpu_to_le64(fsc)));
> + ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), cpu_to_le64(ta)));
> +
> + wmb();
> +
> + if (ta & RISCV_IOMMU_PC_TA_V) {
> + riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
> + riscv_iommu_iofence_sync(ep->iommu);
> + }
> +
> + dev_info(dev, "domain type %d attached w/ PSCID %u PASID %u\n",
> + domain->domain.type, domain->pscid, domain->pasid);
> +
> + return 0;
> +}
> +
> +static void riscv_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
> +{
> + struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> + struct riscv_iommu_command cmd;
> + unsigned long payload = riscv_iommu_ats_inval_all_payload(false);
> + u64 ta;
> +
> + /* invalidate TA.V */
> + ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), 0));
> +
> + wmb();
> +
> + dev_info(dev, "domain removed w/ PSCID %u PASID %u\n",
> + (unsigned)FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta), pasid);
> +
> + /* 1. invalidate PDT entry */
> + riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
> +
> + /* 2. invalidate all matching IOATC entries (if PASID was valid) */
> + if (ta & RISCV_IOMMU_PC_TA_V) {
> + riscv_iommu_cmd_inval_vma(&cmd);
> + riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
> + riscv_iommu_cmd_inval_set_pscid(&cmd,
> + FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta));
> + riscv_iommu_post(ep->iommu, &cmd);
> + }
> +
> + /* 3. Wait IOATC flush to happen */
> + riscv_iommu_iofence_sync(ep->iommu);
> +
> + /* 4. ATS invalidation */
> + riscv_iommu_cmd_ats_inval(&cmd);
> + riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
> + riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
> + riscv_iommu_cmd_ats_set_pid(&cmd, pasid);
> + riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> + riscv_iommu_post(ep->iommu, &cmd);
> +
> + /* 5. Wait DevATC flush to happen */
> + riscv_iommu_iofence_sync(ep->iommu);
> +}
> +
> static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
> unsigned long *start, unsigned long *end,
> size_t *pgsize)
> {
> struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
> struct riscv_iommu_command cmd;
> + struct riscv_iommu_endpoint *endpoint;
> unsigned long iova;
> + unsigned long payload;
>
> if (domain->mode == RISCV_IOMMU_DC_FSC_MODE_BARE)
> return;
> @@ -1065,6 +1616,12 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
> /* Domain not attached to an IOMMU! */
> BUG_ON(!domain->iommu);
>
> + if (start && end) {
> + payload = riscv_iommu_ats_inval_payload(*start, *end, true);
> + } else {
> + payload = riscv_iommu_ats_inval_all_payload(true);
> + }
> +
> riscv_iommu_cmd_inval_vma(&cmd);
> riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
>
> @@ -1078,6 +1635,20 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
> riscv_iommu_post(domain->iommu, &cmd);
> }
> riscv_iommu_iofence_sync(domain->iommu);
> +
> + /* ATS invalidation for every device and for every translation */
> + list_for_each_entry(endpoint, &domain->endpoints, domain) {
> + if (!endpoint->pasid_enabled)
> + continue;
> +
> + riscv_iommu_cmd_ats_inval(&cmd);
> + riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
> + riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
> + riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
> + riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> + riscv_iommu_post(domain->iommu, &cmd);
> + }
> + riscv_iommu_iofence_sync(domain->iommu);
> }
>
> static void riscv_iommu_flush_iotlb_all(struct iommu_domain *iommu_domain)
> @@ -1310,6 +1881,7 @@ static int riscv_iommu_enable(struct riscv_iommu_device *iommu, unsigned request
> static const struct iommu_domain_ops riscv_iommu_domain_ops = {
> .free = riscv_iommu_domain_free,
> .attach_dev = riscv_iommu_attach_dev,
> + .set_dev_pasid = riscv_iommu_set_dev_pasid,
> .map_pages = riscv_iommu_map_pages,
> .unmap_pages = riscv_iommu_unmap_pages,
> .iova_to_phys = riscv_iommu_iova_to_phys,
> @@ -1326,9 +1898,13 @@ static const struct iommu_ops riscv_iommu_ops = {
> .probe_device = riscv_iommu_probe_device,
> .probe_finalize = riscv_iommu_probe_finalize,
> .release_device = riscv_iommu_release_device,
> + .remove_dev_pasid = riscv_iommu_remove_dev_pasid,
> .device_group = riscv_iommu_device_group,
> .get_resv_regions = riscv_iommu_get_resv_regions,
> .of_xlate = riscv_iommu_of_xlate,
> + .dev_enable_feat = riscv_iommu_dev_enable_feat,
> + .dev_disable_feat = riscv_iommu_dev_disable_feat,
> + .page_response = riscv_iommu_page_response,
> .default_domain_ops = &riscv_iommu_domain_ops,
> };
>
> @@ -1340,6 +1916,7 @@ void riscv_iommu_remove(struct riscv_iommu_device *iommu)
> riscv_iommu_queue_free(iommu, &iommu->cmdq);
> riscv_iommu_queue_free(iommu, &iommu->fltq);
> riscv_iommu_queue_free(iommu, &iommu->priq);
> + iopf_queue_free(iommu->pq_work);
> }
>
> int riscv_iommu_init(struct riscv_iommu_device *iommu)
> @@ -1362,6 +1939,12 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
> }
> #endif
>
> + if (iommu->cap & RISCV_IOMMU_CAP_PD20)
> + iommu->iommu.max_pasids = 1u << 20;
> + else if (iommu->cap & RISCV_IOMMU_CAP_PD17)
> + iommu->iommu.max_pasids = 1u << 17;
> + else if (iommu->cap & RISCV_IOMMU_CAP_PD8)
> + iommu->iommu.max_pasids = 1u << 8;
> /*
> * Assign queue lengths from module parameters if not already
> * set on the device tree.
> @@ -1387,6 +1970,13 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
> goto fail;
> if (!(iommu->cap & RISCV_IOMMU_CAP_ATS))
> goto no_ats;
> + /* PRI functionally depends on ATS’s capabilities. */
> + iommu->pq_work = iopf_queue_alloc(dev_name(dev));
> + if (!iommu->pq_work) {
> + dev_err(dev, "failed to allocate iopf queue\n");
> + ret = -ENOMEM;
> + goto fail;
> + }
>
> ret = riscv_iommu_queue_init(iommu, RISCV_IOMMU_PAGE_REQUEST_QUEUE);
> if (ret)
> @@ -1424,5 +2014,6 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
> riscv_iommu_queue_free(iommu, &iommu->priq);
> riscv_iommu_queue_free(iommu, &iommu->fltq);
> riscv_iommu_queue_free(iommu, &iommu->cmdq);
> + iopf_queue_free(iommu->pq_work);
> return ret;
> }
> diff --git a/drivers/iommu/riscv/iommu.h b/drivers/iommu/riscv/iommu.h
> index fe32a4eff14e..83e8d00fd0f8 100644
> --- a/drivers/iommu/riscv/iommu.h
> +++ b/drivers/iommu/riscv/iommu.h
> @@ -17,9 +17,11 @@
> #include <linux/iova.h>
> #include <linux/io.h>
> #include <linux/idr.h>
> +#include <linux/mmu_notifier.h>
> #include <linux/list.h>
> #include <linux/iommu.h>
> #include <linux/io-pgtable.h>
> +#include <linux/mmu_notifier.h>

You include mmu_notifier.h twice in this header; one of the two includes can be dropped.

>
> #include "iommu-bits.h"
>
> @@ -76,6 +78,9 @@ struct riscv_iommu_device {
> unsigned ddt_mode;
> bool ddtp_in_iomem;
>
> + /* I/O page fault queue */
> + struct iopf_queue *pq_work;
> +
> /* hardware queues */
> struct riscv_iommu_queue cmdq;
> struct riscv_iommu_queue fltq;
> @@ -91,11 +96,14 @@ struct riscv_iommu_domain {
> struct io_pgtable pgtbl;
>
> struct list_head endpoints;
> + struct list_head notifiers;
> struct mutex lock;
> + struct mmu_notifier mn;
> struct riscv_iommu_device *iommu;
>
> unsigned mode; /* RIO_ATP_MODE_* enum */
> unsigned pscid; /* RISC-V IOMMU PSCID */
> + ioasid_t pasid; /* IOMMU_DOMAIN_SVA: Cached PASID */
>
> pgd_t *pgd_root; /* page table root pointer */
> };
> @@ -107,10 +115,16 @@ struct riscv_iommu_endpoint {
> unsigned domid; /* PCI domain number, segment */
> struct rb_node node; /* device tracking node (lookup by devid) */
> struct riscv_iommu_dc *dc; /* device context pointer */
> + struct riscv_iommu_pc *pc; /* process context root, valid if pasid_enabled is true */
> struct riscv_iommu_device *iommu; /* parent iommu device */
>
> struct mutex lock;
> struct list_head domain; /* endpoint attached managed domain */
> +
> + /* end point info bits */
> + unsigned pasid_bits;
> + unsigned pasid_feat;
> + bool pasid_enabled;
> };
>
> /* Helper functions and macros */
> --
> 2.34.1
>
>
> _______________________________________________
> linux-riscv mailing list
> linux-riscv@xxxxxxxxxxxxxxxxxxx
> http://lists.infradead.org/mailman/listinfo/linux-riscv