Re: [PATCH v4 26/66] mm/mmap: Reorganize munmap to use maple states
From: Liam Howlett
Date: Fri Jan 21 2022 - 14:32:28 EST
* Vlastimil Babka <vbabka@xxxxxxx> [220118 05:40]:
> On 12/1/21 15:30, Liam Howlett wrote:
> > From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>
> >
> > Remove __do_munmap() in favour of do_munmap(), do_mas_munmap(), and
> > do_mas_align_munmap().
> >
> > do_munmap() is a wrapper to create a maple state for any callers that
> > have not been converted to the maple tree.
> >
> > do_mas_munmap() takes a maple state to munmap a range. This is just a
> > small function which checks for error conditions and aligns the end of
> > the range.
> >
> > do_mas_align_munmap() uses the aligned range to munmap a range.
> > do_mas_align_munmap() starts with the first VMA in the range, then finds
> > the last VMA in the range. Both start and end are split if necessary.
> > Then the VMAs are unlocked and removed from the linked list at the same
> > time, followed by a single tree operation that overwrites the area
> > with a NULL. Finally, the detached list is unmapped and freed.
> >
> > By reorganizing the munmap calls as outlined, it is now possible to
> > avoid the extra work of aligning pre-aligned callers that are known to be
> > safe, and to avoid extra VMA lookups or tree walks for modifications.
> >
> > detach_vmas_to_be_unmapped() is no longer used, so drop this code.
> >
> > Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
>
> <snip>
>
> > -/* Munmap is split into 2 main parts -- this part which finds
> > - * what needs doing, and the areas themselves, which do the
> > - * work. This now handles partial unmappings.
> > - * Jeremy Fitzhardinge <jeremy@xxxxxxxx>
> > +/*
> > + * do_mas_align_munmap() - munmap the aligned region from @start to @end.
> > + * @mas: The maple_state, ideally set up to alter the correct tree location.
> > + * @vma: The starting vm_area_struct
> > + * @mm: The mm_struct
> > + * @start: The aligned start address to munmap.
> > + * @end: The aligned end address to munmap.
> > + * @uf: The userfaultfd list_head
> > + * @downgrade: Set to true to attempt a downwrite of the mmap_sem
>
> s/downwrite/write downgrade/?
Yes, thanks.
>
> > + *
> > + * @mas must be locked before calling this function. If @downgrade is true,
> > + * check return code for potential release of the lock.
>
> How is 'mas' locked? The downgrade still calls mmap_write_downgrade(mm). It
> should say "mm's mmap_lock should be write locked" no?
Yes, this comment should have been updated with the change to the
locking. In fact, it should be clear that the lock must be held, so I'll
drop this part of the comment entirely.
>
> > */
> > -int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
> > - struct list_head *uf, bool downgrade)
> > +static int
> > +do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
> > + struct mm_struct *mm, unsigned long start,
> > + unsigned long end, struct list_head *uf, bool downgrade)
> > {
> > - unsigned long end;
> > - struct vm_area_struct *vma, *prev, *last;
> > -
> > - if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE-start)
> > - return -EINVAL;
> > -
> > - len = PAGE_ALIGN(len);
> > - end = start + len;
> > - if (len == 0)
> > - return -EINVAL;
> > -
> > - /* arch_unmap() might do unmaps itself. */
> > - arch_unmap(mm, start, end);
> > -
> > - /* Find the first overlapping VMA where start < vma->vm_end */
> > - vma = find_vma_intersection(mm, start, end);
> > - if (!vma)
> > - return 0;
> > - prev = vma->vm_prev;
> > + struct vm_area_struct *prev, *last;
> > /* we have start < vma->vm_end */
> >
> > /*
> > @@ -2458,16 +2418,26 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
> > if (error)
> > return error;
> > prev = vma;
> > + vma = __vma_next(mm, prev);
> > + mas->index = start;
> > + mas_reset(mas);
> > + } else {
> > + prev = vma->vm_prev;
> > }
> >
> > + if (vma->vm_end >= end)
> > + last = vma;
> > + else
> > + last = find_vma_intersection(mm, end - 1, end);
> > +
> > /* Does it split the last one? */
> > - last = find_vma(mm, end);
> > - if (last && end > last->vm_start) {
> > + if (last && end < last->vm_end) {
> > int error = __split_vma(mm, last, end, 1);
> > if (error)
> > return error;
> > + vma = __vma_next(mm, prev);
>
> Should be needed only if last == vma?
Yes, it's safe as is, but it is probably better to do it only if
last == vma. This is only needed for the linked list, so it will
eventually be removed entirely.
>
> > + mas_reset(mas);
> > }
> > - vma = __vma_next(mm, prev);
> >
> > if (unlikely(uf)) {
> > /*
> > @@ -2480,22 +2450,47 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
> > * failure that it's not worth optimizing it for.
> > */
> > int error = userfaultfd_unmap_prep(vma, start, end, uf);
> > +
> > if (error)
> > return error;
> > }
> >
> > /*
> > - * unlock any mlock()ed ranges before detaching vmas
> > + * unlock any mlock()ed ranges before detaching vmas, count the number
> > + * of VMAs to be dropped, and return the tail entry of the affected
> > + * area.
> > */
> > - if (mm->locked_vm)
> > - unlock_range(vma, end);
> > + mm->map_count -= unlock_range(vma, &last, end);
> > + /* Drop removed area from the tree */
> > + mas_store_gfp(mas, NULL, GFP_KERNEL);
> >
> > - /* Detach vmas from the MM linked list and remove from the mm tree*/
> > - if (!detach_vmas_to_be_unmapped(mm, vma, prev, end))
> > - downgrade = false;
> > + /* Detach vmas from the MM linked list */
> > + vma->vm_prev = NULL;
> > + if (prev)
> > + prev->vm_next = last->vm_next;
> > + else
> > + mm->mmap = last->vm_next;
> >
> > - if (downgrade)
> > - mmap_write_downgrade(mm);
> > + if (last->vm_next) {
> > + last->vm_next->vm_prev = prev;
> > + last->vm_next = NULL;
> > + } else
> > + mm->highest_vm_end = prev ? vm_end_gap(prev) : 0;
> > +
> > + /*
> > + * Do not downgrade mmap_lock if we are next to VM_GROWSDOWN or
> > + * VM_GROWSUP VMA. Such VMAs can change their size under
> > + * down_read(mmap_lock) and collide with the VMA we are about to unmap.
> > + */
> > + if (downgrade) {
> > + if (last && (last->vm_flags & VM_GROWSDOWN))
> > + downgrade = false;
> > + else if (prev && (prev->vm_flags & VM_GROWSUP))
> > + downgrade = false;
> > + else {
> > + mmap_write_downgrade(mm);
> > + }
>
> remove { } brackets?
Yes, thanks.
>
> > + }
> >
> > unmap_region(mm, vma, prev, start, end);
> >
> > @@ -2505,10 +2500,61 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
> > return downgrade ? 1 : 0;
> > }
> >
> > +/*
> > + * do_mas_munmap() - munmap a given range.
> > + * @mas: The maple state
> > + * @mm: The mm_struct
> > + * @start: The start address to munmap
> > + * @len: The length of the range to munmap
> > + * @uf: The userfaultfd list_head
> > + * @downgrade: set to true if the user wants to attempt to write_downgrade the
> > + * mmap_sem
> > + *
> > + * This function takes a @mas that is in the correct state to remove the
> > + * mapping(s). The @len will be aligned and any arch_unmap work will be
> > + * performed.
> > + * @mas must be locked. @mas may be unlocked if @degraded is true.
This comment needs updating too.
> > + *
> > + * Returns: -EINVAL on failure, 1 on success and unlock, 0 otherwise.
> > + */
> > +int do_mas_munmap(struct ma_state *mas, struct mm_struct *mm,
> > + unsigned long start, size_t len, struct list_head *uf,
> > + bool downgrade)
> > +{
> > + unsigned long end;
> > + struct vm_area_struct *vma;
> > +
> > + if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE-start)
> > + return -EINVAL;
> > +
> > + end = start + PAGE_ALIGN(len);
> > + if (end == start)
> > + return -EINVAL;
> > +
> > + /* arch_unmap() might do unmaps itself. */
> > + arch_unmap(mm, start, end);
> > +
> > + /* Find the first overlapping VMA */
> > + vma = mas_find(mas, end - 1);
> > + if (!vma)
> > + return 0;
> > +
> > + mas->last = end - 1;
>
> Why not set this before mas_find() above? Hm but that takes its own second
> parameter instead of looking at mas->last. To be honest, I'm a bit confused
> wrt the role of mas->last in the API. Perhaps another suggestion for the
> "how to improve docs" discussion earlier. Or maybe I just missed/forgot it.
Hmm, maybe the doc doesn't specifically say what it is, but the header
does:
unsigned long last; /* The last index we're operating on -
range end */
I'll have a look at the documentation to make it clearer.
mas->index and mas->last represent the range. If you walk to an entry,
then mas->index is set to the start address of that entry and mas->last
is set to its inclusive end. I set mas->last because we need to munmap
from the first overlapping VMA to the end of the range, but there may be
more than one VMA, which would mean mas->last is less than end - 1 at
this point.
>
> > + return do_mas_align_munmap(mas, vma, mm, start, end, uf, downgrade);
> > +}
> > +
>
> <snip>
>
> > diff --git a/mm/mremap.c b/mm/mremap.c
> > index 002eec83e91e..b09e107cd18b 100644
> > --- a/mm/mremap.c
> > +++ b/mm/mremap.c
> > @@ -978,20 +978,23 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
> > /*
> > * Always allow a shrinking remap: that just unmaps
> > * the unnecessary pages..
> > - * __do_munmap does all the needed commit accounting, and
> > + * do_mas_munmap does all the needed commit accounting, and
> > * downgrades mmap_lock to read if so directed.
> > */
> > if (old_len >= new_len) {
> > int retval;
> > + MA_STATE(mas, &mm->mm_mt, addr + new_len, addr + new_len);
> >
> > - retval = __do_munmap(mm, addr+new_len, old_len - new_len,
> > - &uf_unmap, true);
> > - if (retval < 0 && old_len != new_len) {
> > - ret = retval;
> > - goto out;
> > + retval = do_mas_munmap(&mas, mm, addr + new_len,
> > + old_len - new_len, &uf_unmap, true);
> > /* Returning 1 indicates mmap_lock is downgraded to read. */
> > - } else if (retval == 1)
> > + if (retval == 1) {
> > downgraded = true;
> > + } else if (retval < 0 && old_len != new_len) {
> > + ret = retval;
> > + goto out;
> > + }
> > +
> > ret = addr;
> > goto out;
> > }
> > @@ -1006,7 +1009,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
> > }
> >
> > /* old_len exactly to the end of the area..
> > - */
> > + */
>
> Spurious edit?
>
Ack
> > if (old_len == vma->vm_end - addr) {
> > /* can we just expand the current mapping? */
> > if (vma_expandable(vma, new_len - old_len)) {
> > @@ -1048,9 +1051,9 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
> > map_flags |= MAP_SHARED;
> >
> > new_addr = get_unmapped_area(vma->vm_file, 0, new_len,
> > - vma->vm_pgoff +
> > - ((addr - vma->vm_start) >> PAGE_SHIFT),
> > - map_flags);
> > + vma->vm_pgoff +
> > + ((addr - vma->vm_start) >> PAGE_SHIFT),
> > + map_flags);
>
> And this?
Ack
>
> > if (IS_ERR_VALUE(new_addr)) {
> > ret = new_addr;
> > goto out;
>
>