Re: [v6 PATCH 4/5] iommu: Support mm PASID 1:n with sva domains

From: Nicolin Chen
Date: Wed Oct 11 2023 - 15:34:17 EST


On Wed, Oct 11, 2023 at 09:26:12PM +0800, Tina Zhang wrote:
> On 10/11/23 20:39, Jason Gunthorpe wrote:
> > On Wed, Oct 11, 2023 at 02:51:31PM +0800, Tina Zhang wrote:
> >
> > > diff --git a/kernel/fork.c b/kernel/fork.c
> > > index 3b6d20dfb9a8..985403a7a747 100644
> > > --- a/kernel/fork.c
> > > +++ b/kernel/fork.c
> > > @@ -1277,7 +1277,6 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
> > > mm_init_cpumask(mm);
> > > mm_init_aio(mm);
> > > mm_init_owner(mm, p);
> > > - mm_pasid_init(mm);
> > > RCU_INIT_POINTER(mm->exe_file, NULL);
> > > mmu_notifier_subscriptions_init(mm);
> > > init_tlb_flush_pending(mm);
> >
> > Nicolin debugged his crash report last night and sent me the details.
> >
> > This hunk is the cause of the bug that Nicolin reported.
> >
> > The dup_mm() flow does:
> >
> > static struct mm_struct *dup_mm(struct task_struct *tsk,
> > struct mm_struct *oldmm)
> > {
> > struct mm_struct *mm;
> > int err;
> >
> > mm = allocate_mm();
> > if (!mm)
> > goto fail_nomem;
> >
> > memcpy(mm, oldmm, sizeof(*mm));
> >
> > if (!mm_init(mm, tsk, mm->user_ns))
> > goto fail_nomem;
> >
> > It is essential that mm_pasid_init() zero the new pointer. Otherwise,
> > due to the memcpy, after a fork two mm structs will point to the same
> > thing and one will UAF/double free.
> Good catch.
>
> Thanks,
> -Tina
> >
> > Keep mm_pasid_init() and add zeroing the new pointer to it.

Yeah, with the change below applied I no longer see the WARN_ON:

---------------------------------------------------------
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 3d782fd0f485..4bc3c49cdaf9 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -1208,2 +1208,6 @@ static inline bool tegra_dev_iommu_get_stream_id(struct device *dev, u32 *stream
#ifdef CONFIG_IOMMU_SVA
+static inline void mm_pasid_init(struct mm_struct *mm)
+{
+ mm->iommu_mm = NULL;
+}
static inline bool mm_valid_pasid(struct mm_struct *mm)
@@ -1240,2 +1244,3 @@ static inline u32 iommu_sva_get_pasid(struct iommu_sva *handle)
}
+static inline void mm_pasid_init(struct mm_struct *mm) {}
static inline bool mm_valid_pasid(struct mm_struct *mm) { return false; }
diff --git a/kernel/fork.c b/kernel/fork.c
index f06392dd1ca8..d2e12b6d2b18 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1276,2 +1276,3 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
mm_init_owner(mm, p);
+ mm_pasid_init(mm);
RCU_INIT_POINTER(mm->exe_file, NULL);
---------------------------------------------------------
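
For anyone following along, here is a minimal userspace sketch of why the
reset matters. It is not kernel code: struct fake_mm, struct fake_iommu_mm
and the fake_*() helpers are hypothetical stand-ins that only mirror the
allocate/memcpy/re-init shape of dup_mm() quoted above.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical stand-in for the per-mm IOMMU data. */
struct fake_iommu_mm {
	unsigned int pasid;
};

/* Hypothetical stand-in for struct mm_struct. */
struct fake_mm {
	int other_state;
	struct fake_iommu_mm *iommu_mm;
};

/* Mirrors the fix: clear the pointer that memcpy() carried over. */
static void fake_mm_pasid_init(struct fake_mm *mm)
{
	mm->iommu_mm = NULL;
}

/* Mirrors the dup_mm() shape: allocate, memcpy from the parent, re-init. */
static struct fake_mm *fake_dup_mm(const struct fake_mm *oldmm)
{
	struct fake_mm *mm;

	assert(oldmm);
	mm = malloc(sizeof(*mm));
	if (!mm)
		return NULL;

	memcpy(mm, oldmm, sizeof(*mm));
	/* Without this call both structs would own the same iommu_mm. */
	fake_mm_pasid_init(mm);
	return mm;
}

/* Teardown frees whatever the mm owns; free(NULL) is a no-op. */
static void fake_mm_free(struct fake_mm *mm)
{
	free(mm->iommu_mm);
	free(mm);
}
```

With the fake_mm_pasid_init() call removed, calling fake_mm_free() on both
the parent and the child would free the same iommu_mm twice, which is the
userspace analogue of the UAF/double free described above.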

I'll confirm with v7 too.

Thanks
Nicolin