Re: [PATCH v2 02/11] mm/hmm: use reference counting for HMM struct v2

From: Ira Weiny
Date: Thu Mar 28 2019 - 15:08:30 EST


On Mon, Mar 25, 2019 at 10:40:02AM -0400, Jerome Glisse wrote:
> From: Jérôme Glisse <jglisse@xxxxxxxxxx>
>
> Every time I read the code to check that the HMM structure does not
> vanish before it should, thanks to the many locks protecting its removal,
> I get a headache. Switch to reference counting instead; it is much
> easier to follow and harder to break. This also removes some code that
> is no longer needed with refcounting.
>
> Changes since v1:
> - removed a bunch of useless checks (if the API is used with bogus
> arguments, it is better to fail loudly so that users fix their code)
> - s/hmm_get/mm_get_hmm/
>
> Signed-off-by: Jérôme Glisse <jglisse@xxxxxxxxxx>
> Reviewed-by: Ralph Campbell <rcampbell@xxxxxxxxxx>
> Cc: John Hubbard <jhubbard@xxxxxxxxxx>
> Cc: Andrew Morton <akpm@xxxxxxxxxxxxxxxxxxxx>
> Cc: Dan Williams <dan.j.williams@xxxxxxxxx>
> ---
> include/linux/hmm.h | 2 +
> mm/hmm.c | 170 ++++++++++++++++++++++++++++----------------
> 2 files changed, 112 insertions(+), 60 deletions(-)
>
> diff --git a/include/linux/hmm.h b/include/linux/hmm.h
> index ad50b7b4f141..716fc61fa6d4 100644
> --- a/include/linux/hmm.h
> +++ b/include/linux/hmm.h
> @@ -131,6 +131,7 @@ enum hmm_pfn_value_e {
> /*
> * struct hmm_range - track invalidation lock on virtual address range
> *
> + * @hmm: the core HMM structure this range is active against
> * @vma: the vm area struct for the range
> * @list: all range lock are on a list
> * @start: range virtual start address (inclusive)
> @@ -142,6 +143,7 @@ enum hmm_pfn_value_e {
> * @valid: pfns array did not change since it has been fill by an HMM function
> */
> struct hmm_range {
> + struct hmm *hmm;
> struct vm_area_struct *vma;
> struct list_head list;
> unsigned long start;
> diff --git a/mm/hmm.c b/mm/hmm.c
> index fe1cd87e49ac..306e57f7cded 100644
> --- a/mm/hmm.c
> +++ b/mm/hmm.c
> @@ -50,6 +50,7 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
> */
> struct hmm {
> struct mm_struct *mm;
> + struct kref kref;
> spinlock_t lock;
> struct list_head ranges;
> struct list_head mirrors;
> @@ -57,6 +58,16 @@ struct hmm {
> struct rw_semaphore mirrors_sem;
> };
>
> +static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
> +{
> + struct hmm *hmm = READ_ONCE(mm->hmm);
> +
> + if (hmm && kref_get_unless_zero(&hmm->kref))
> + return hmm;
> +
> + return NULL;
> +}
> +
> /*
> * hmm_register - register HMM against an mm (HMM internal)
> *
> @@ -67,14 +78,9 @@ struct hmm {
> */
> static struct hmm *hmm_register(struct mm_struct *mm)
> {
> - struct hmm *hmm = READ_ONCE(mm->hmm);
> + struct hmm *hmm = mm_get_hmm(mm);

FWIW: having hmm_register() double as the "hmm get" (reference-taking) operation is a bit confusing...

Ira

> bool cleanup = false;
>
> - /*
> - * The hmm struct can only be freed once the mm_struct goes away,
> - * hence we should always have pre-allocated an new hmm struct
> - * above.
> - */
> if (hmm)
> return hmm;
>
> @@ -86,6 +92,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
> hmm->mmu_notifier.ops = NULL;
> INIT_LIST_HEAD(&hmm->ranges);
> spin_lock_init(&hmm->lock);
> + kref_init(&hmm->kref);
> hmm->mm = mm;
>
> spin_lock(&mm->page_table_lock);
> @@ -106,7 +113,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
> if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
> goto error_mm;
>
> - return mm->hmm;
> + return hmm;
>
> error_mm:
> spin_lock(&mm->page_table_lock);
> @@ -118,9 +125,41 @@ static struct hmm *hmm_register(struct mm_struct *mm)
> return NULL;
> }
>
> +static void hmm_free(struct kref *kref)
> +{
> + struct hmm *hmm = container_of(kref, struct hmm, kref);
> + struct mm_struct *mm = hmm->mm;
> +
> + mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
> +
> + spin_lock(&mm->page_table_lock);
> + if (mm->hmm == hmm)
> + mm->hmm = NULL;
> + spin_unlock(&mm->page_table_lock);
> +
> + kfree(hmm);
> +}
> +
> +static inline void hmm_put(struct hmm *hmm)
> +{
> + kref_put(&hmm->kref, hmm_free);
> +}
> +
> void hmm_mm_destroy(struct mm_struct *mm)
> {
> - kfree(mm->hmm);
> + struct hmm *hmm;
> +
> + spin_lock(&mm->page_table_lock);
> + hmm = mm_get_hmm(mm);
> + mm->hmm = NULL;
> + if (hmm) {
> + hmm->mm = NULL;
> + spin_unlock(&mm->page_table_lock);
> + hmm_put(hmm);
> + return;
> + }
> +
> + spin_unlock(&mm->page_table_lock);
> }
>
> static int hmm_invalidate_range(struct hmm *hmm, bool device,
> @@ -165,7 +204,7 @@ static int hmm_invalidate_range(struct hmm *hmm, bool device,
> static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
> {
> struct hmm_mirror *mirror;
> - struct hmm *hmm = mm->hmm;
> + struct hmm *hmm = mm_get_hmm(mm);
>
> down_write(&hmm->mirrors_sem);
> mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
> @@ -186,13 +225,16 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
> struct hmm_mirror, list);
> }
> up_write(&hmm->mirrors_sem);
> +
> + hmm_put(hmm);
> }
>
> static int hmm_invalidate_range_start(struct mmu_notifier *mn,
> const struct mmu_notifier_range *range)
> {
> + struct hmm *hmm = mm_get_hmm(range->mm);
> struct hmm_update update;
> - struct hmm *hmm = range->mm->hmm;
> + int ret;
>
> VM_BUG_ON(!hmm);
>
> @@ -200,14 +242,16 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
> update.end = range->end;
> update.event = HMM_UPDATE_INVALIDATE;
> update.blockable = range->blockable;
> - return hmm_invalidate_range(hmm, true, &update);
> + ret = hmm_invalidate_range(hmm, true, &update);
> + hmm_put(hmm);
> + return ret;
> }
>
> static void hmm_invalidate_range_end(struct mmu_notifier *mn,
> const struct mmu_notifier_range *range)
> {
> + struct hmm *hmm = mm_get_hmm(range->mm);
> struct hmm_update update;
> - struct hmm *hmm = range->mm->hmm;
>
> VM_BUG_ON(!hmm);
>
> @@ -216,6 +260,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
> update.event = HMM_UPDATE_INVALIDATE;
> update.blockable = true;
> hmm_invalidate_range(hmm, false, &update);
> + hmm_put(hmm);
> }
>
> static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
> @@ -241,24 +286,13 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
> if (!mm || !mirror || !mirror->ops)
> return -EINVAL;
>
> -again:
> mirror->hmm = hmm_register(mm);
> if (!mirror->hmm)
> return -ENOMEM;
>
> down_write(&mirror->hmm->mirrors_sem);
> - if (mirror->hmm->mm == NULL) {
> - /*
> - * A racing hmm_mirror_unregister() is about to destroy the hmm
> - * struct. Try again to allocate a new one.
> - */
> - up_write(&mirror->hmm->mirrors_sem);
> - mirror->hmm = NULL;
> - goto again;
> - } else {
> - list_add(&mirror->list, &mirror->hmm->mirrors);
> - up_write(&mirror->hmm->mirrors_sem);
> - }
> + list_add(&mirror->list, &mirror->hmm->mirrors);
> + up_write(&mirror->hmm->mirrors_sem);
>
> return 0;
> }
> @@ -273,33 +307,18 @@ EXPORT_SYMBOL(hmm_mirror_register);
> */
> void hmm_mirror_unregister(struct hmm_mirror *mirror)
> {
> - bool should_unregister = false;
> - struct mm_struct *mm;
> - struct hmm *hmm;
> + struct hmm *hmm = READ_ONCE(mirror->hmm);
>
> - if (mirror->hmm == NULL)
> + if (hmm == NULL)
> return;
>
> - hmm = mirror->hmm;
> down_write(&hmm->mirrors_sem);
> list_del_init(&mirror->list);
> - should_unregister = list_empty(&hmm->mirrors);
> + /* To protect us against double unregister ... */
> mirror->hmm = NULL;
> - mm = hmm->mm;
> - hmm->mm = NULL;
> up_write(&hmm->mirrors_sem);
>
> - if (!should_unregister || mm == NULL)
> - return;
> -
> - mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
> -
> - spin_lock(&mm->page_table_lock);
> - if (mm->hmm == hmm)
> - mm->hmm = NULL;
> - spin_unlock(&mm->page_table_lock);
> -
> - kfree(hmm);
> + hmm_put(hmm);
> }
> EXPORT_SYMBOL(hmm_mirror_unregister);
>
> @@ -708,6 +727,8 @@ int hmm_vma_get_pfns(struct hmm_range *range)
> struct mm_walk mm_walk;
> struct hmm *hmm;
>
> + range->hmm = NULL;
> +
> /* Sanity check, this really should not happen ! */
> if (range->start < vma->vm_start || range->start >= vma->vm_end)
> return -EINVAL;
> @@ -717,14 +738,18 @@ int hmm_vma_get_pfns(struct hmm_range *range)
> hmm = hmm_register(vma->vm_mm);
> if (!hmm)
> return -ENOMEM;
> - /* Caller must have registered a mirror, via hmm_mirror_register() ! */
> - if (!hmm->mmu_notifier.ops)
> +
> + /* Check if hmm_mm_destroy() was call. */
> + if (hmm->mm == NULL) {
> + hmm_put(hmm);
> return -EINVAL;
> + }
>
> /* FIXME support hugetlb fs */
> if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
> vma_is_dax(vma)) {
> hmm_pfns_special(range);
> + hmm_put(hmm);
> return -EINVAL;
> }
>
> @@ -736,6 +761,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
> * operations such has atomic access would not work.
> */
> hmm_pfns_clear(range, range->pfns, range->start, range->end);
> + hmm_put(hmm);
> return -EPERM;
> }
>
> @@ -758,6 +784,12 @@ int hmm_vma_get_pfns(struct hmm_range *range)
> mm_walk.pte_hole = hmm_vma_walk_hole;
>
> walk_page_range(range->start, range->end, &mm_walk);
> + /*
> + * Transfer hmm reference to the range struct it will be drop inside
> + * the hmm_vma_range_done() function (which _must_ be call if this
> + * function return 0).
> + */
> + range->hmm = hmm;
> return 0;
> }
> EXPORT_SYMBOL(hmm_vma_get_pfns);
> @@ -802,25 +834,27 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
> */
> bool hmm_vma_range_done(struct hmm_range *range)
> {
> - unsigned long npages = (range->end - range->start) >> PAGE_SHIFT;
> - struct hmm *hmm;
> + bool ret = false;
>
> - if (range->end <= range->start) {
> + /* Sanity check this really should not happen. */
> + if (range->hmm == NULL || range->end <= range->start) {
> BUG();
> return false;
> }
>
> - hmm = hmm_register(range->vma->vm_mm);
> - if (!hmm) {
> - memset(range->pfns, 0, sizeof(*range->pfns) * npages);
> - return false;
> - }
> -
> - spin_lock(&hmm->lock);
> + spin_lock(&range->hmm->lock);
> list_del_rcu(&range->list);
> - spin_unlock(&hmm->lock);
> + ret = range->valid;
> + spin_unlock(&range->hmm->lock);
>
> - return range->valid;
> + /* Is the mm still alive ? */
> + if (range->hmm->mm == NULL)
> + ret = false;
> +
> + /* Drop reference taken by hmm_vma_fault() or hmm_vma_get_pfns() */
> + hmm_put(range->hmm);
> + range->hmm = NULL;
> + return ret;
> }
> EXPORT_SYMBOL(hmm_vma_range_done);
>
> @@ -880,6 +914,8 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
> struct hmm *hmm;
> int ret;
>
> + range->hmm = NULL;
> +
> /* Sanity check, this really should not happen ! */
> if (range->start < vma->vm_start || range->start >= vma->vm_end)
> return -EINVAL;
> @@ -891,14 +927,18 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
> hmm_pfns_clear(range, range->pfns, range->start, range->end);
> return -ENOMEM;
> }
> - /* Caller must have registered a mirror using hmm_mirror_register() */
> - if (!hmm->mmu_notifier.ops)
> +
> + /* Check if hmm_mm_destroy() was call. */
> + if (hmm->mm == NULL) {
> + hmm_put(hmm);
> return -EINVAL;
> + }
>
> /* FIXME support hugetlb fs */
> if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
> vma_is_dax(vma)) {
> hmm_pfns_special(range);
> + hmm_put(hmm);
> return -EINVAL;
> }
>
> @@ -910,6 +950,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
> * operations such has atomic access would not work.
> */
> hmm_pfns_clear(range, range->pfns, range->start, range->end);
> + hmm_put(hmm);
> return -EPERM;
> }
>
> @@ -945,7 +986,16 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
> hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
> range->end);
> hmm_vma_range_done(range);
> + hmm_put(hmm);
> + } else {
> + /*
> + * Transfer hmm reference to the range struct it will be drop
> + * inside the hmm_vma_range_done() function (which _must_ be
> + * call if this function return 0).
> + */
> + range->hmm = hmm;
> }
> +
> return ret;
> }
> EXPORT_SYMBOL(hmm_vma_fault);
> --
> 2.17.2
>