[PATCH] mm: Do early cow for FOLL_PIN pages during fork()

From: Peter Xu
Date: Fri Sep 18 2020 - 11:56:27 EST

This patch makes fork() keep DMA-pinned pages stable in the parent:

  - Add mm->has_pinned, which is set (and never cleared) in both the
    slow and the fast gup paths whenever pages are pinned with
    FOLL_PIN.

  - During fork(), when copying a pte of a COW mapping whose page may
    be DMA-pinned (mm->has_pinned is set and page_maybe_dma_pinned()
    returns true), do not write-protect the parent. Instead, duplicate
    the page right away: the child gets the copy, while the parent
    keeps its original, still-writable mapping so that ongoing DMA
    keeps targeting the page the parent sees.

  - If allocating the copy fails, fall back to the old behavior of
    sharing and write-protecting the page.


Co-developed-by: Jason Gunthorpe <jgg@xxxxxxxx>
Signed-off-by: Peter Xu <peterx@xxxxxxxxxx>
---
include/linux/mm.h | 2 +-
include/linux/mm_types.h | 1 +
kernel/fork.c | 3 +-
mm/gup.c | 15 +++++++
mm/memory.c | 85 +++++++++++++++++++++++++++++++++++-----
5 files changed, 94 insertions(+), 12 deletions(-)
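
For context: mm->has_pinned below is set by pin_user_pages() and friends,
which add FOLL_PIN internally. As a reference point for the kind of caller
whose pages the fork()-time early copy keeps stable, here is a minimal,
hypothetical driver-style sketch (my_dma_buf and my_dev_pin_buffer() are
made up; it assumes the current pin_user_pages() signature that still takes
a vmas argument). It is illustration only, not part of this patch:

    #include <linux/mm.h>
    #include <linux/sched.h>
    #include <linux/slab.h>

    struct my_dma_buf {                 /* made-up bookkeeping */
            struct page **pages;
            unsigned long npages;
    };

    static int my_dev_pin_buffer(struct my_dma_buf *buf,
                                 unsigned long uaddr, unsigned long npages)
    {
            long pinned;

            buf->pages = kvcalloc(npages, sizeof(*buf->pages), GFP_KERNEL);
            if (!buf->pages)
                    return -ENOMEM;

            mmap_read_lock(current->mm);
            /* FOLL_PIN is implied by pin_user_pages(); this marks the mm */
            pinned = pin_user_pages(uaddr, npages,
                                    FOLL_WRITE | FOLL_LONGTERM,
                                    buf->pages, NULL);
            mmap_read_unlock(current->mm);

            if (pinned < 0) {
                    kvfree(buf->pages);
                    return pinned;
            }
            if (pinned != npages) {
                    unpin_user_pages(buf->pages, pinned);
                    kvfree(buf->pages);
                    return -EFAULT;
            }
            buf->npages = npages;
            return 0;
    }

After such a long-term pin, the parent must keep writing into the same
physical pages that were handed to the device, which is exactly what the
early cow in copy_one_pte() preserves across fork().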

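Also for illustration only: the userspace sketch below (made up for this
mail, no pinning involved; it needs root because unprivileged reads of
/proc/self/pagemap hide PFNs) shows that after fork() the first write from
the parent moves the parent's virtual address onto a new physical page. If
the original page had been pinned for DMA, the device would keep writing to
a page the parent no longer maps; with this patch the child takes the copy
at fork() time instead and the parent keeps its original PFN.

    #define _GNU_SOURCE
    #include <fcntl.h>
    #include <inttypes.h>
    #include <signal.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <sys/mman.h>
    #include <sys/wait.h>
    #include <unistd.h>

    /* PFN of the page backing @addr, from /proc/self/pagemap (root only). */
    static uint64_t pfn_of(void *addr)
    {
            uint64_t ent;
            long pagesz = sysconf(_SC_PAGESIZE);
            int fd = open("/proc/self/pagemap", O_RDONLY);

            if (fd < 0 || pread(fd, &ent, sizeof(ent),
                                (uintptr_t)addr / pagesz * 8) != sizeof(ent)) {
                    perror("pagemap");
                    exit(1);
            }
            close(fd);
            return ent & ((1ULL << 55) - 1);    /* bits 0-54: PFN */
    }

    int main(void)
    {
            volatile char *buf = mmap(NULL, 4096, PROT_READ | PROT_WRITE,
                                      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
            pid_t child;

            if (buf == MAP_FAILED) {
                    perror("mmap");
                    return 1;
            }

            buf[0] = 1;             /* fault the anonymous page in */
            printf("before fork: pfn=%#" PRIx64 "\n", pfn_of((void *)buf));

            child = fork();
            if (child == 0) {       /* child keeps a reference to the page */
                    pause();
                    _exit(0);
            }

            buf[0] = 2;             /* write fault: parent cows to a new page */
            printf("after write: pfn=%#" PRIx64 "\n", pfn_of((void *)buf));

            kill(child, SIGKILL);
            waitpid(child, NULL, 0);
            return 0;
    }
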
diff --git a/include/linux/mm.h b/include/linux/mm.h
index ca6e6a81576b..bf1ac54be55e 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1644,7 +1644,7 @@ struct mmu_notifier_range;
void free_pgd_range(struct mmu_gather *tlb, unsigned long addr,
unsigned long end, unsigned long floor, unsigned long ceiling);
int copy_page_range(struct mm_struct *dst, struct mm_struct *src,
- struct vm_area_struct *vma);
+ struct vm_area_struct *vma, struct vm_area_struct *new);
int follow_pte_pmd(struct mm_struct *mm, unsigned long address,
struct mmu_notifier_range *range,
pte_t **ptepp, pmd_t **pmdpp, spinlock_t **ptlp);
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 496c3ff97cce..b3812fa6383f 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -458,6 +458,7 @@ struct mm_struct {

unsigned long total_vm; /* Total pages mapped */
unsigned long locked_vm; /* Pages that have PG_mlocked set */
+ unsigned long has_pinned; /* Whether this mm has pinned any page */
atomic64_t pinned_vm; /* Refcount permanently increased */
unsigned long data_vm; /* VM_WRITE & ~VM_SHARED & ~VM_STACK */
unsigned long exec_vm; /* VM_EXEC & ~VM_WRITE & ~VM_STACK */
diff --git a/kernel/fork.c b/kernel/fork.c
index 49677d668de4..843807ade6dd 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -589,7 +589,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,

mm->map_count++;
if (!(tmp->vm_flags & VM_WIPEONFORK))
- retval = copy_page_range(mm, oldmm, mpnt);
+ retval = copy_page_range(mm, oldmm, mpnt, tmp);

if (tmp->vm_ops && tmp->vm_ops->open)
tmp->vm_ops->open(tmp);
@@ -1011,6 +1011,7 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
mm_pgtables_bytes_init(mm);
mm->map_count = 0;
mm->locked_vm = 0;
+ mm->has_pinned = 0;
atomic64_set(&mm->pinned_vm, 0);
memset(&mm->rss_stat, 0, sizeof(mm->rss_stat));
spin_lock_init(&mm->page_table_lock);
diff --git a/mm/gup.c b/mm/gup.c
index e5739a1974d5..cab10cefefe4 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1255,6 +1255,17 @@ static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
BUG_ON(*locked != 1);
}

+ /*
+ * Mark the mm struct if there's any page pinning attempt. We're
+ * aggressive on this bit since even if the pinned pages were unpinned
+ * later on, we'll still keep this bit set for this address space just
+ * to make everything easy.
+ *
+ * TODO: Ideally we can use mm->pinned_vm but only until it's stable.
+ */
+ if (flags & FOLL_PIN)
+ WRITE_ONCE(mm->has_pinned, 1);
+
/*
* FOLL_PIN and FOLL_GET are mutually exclusive. Traditional behavior
* is to set FOLL_GET if the caller wants pages[] filled in (but has
@@ -2660,6 +2671,10 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
FOLL_FAST_ONLY)))
return -EINVAL;

+ /* Please refer to comments in __get_user_pages_locked() */
+ if (gup_flags & FOLL_PIN)
+ WRITE_ONCE(current->mm->has_pinned, 1);
+
if (!(gup_flags & FOLL_FAST_ONLY))
might_lock_read(&current->mm->mmap_lock);

diff --git a/mm/memory.c b/mm/memory.c
index 469af373ae76..be9e2946c462 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -689,6 +689,38 @@ struct page *vm_normal_page_pmd(struct vm_area_struct *vma, unsigned long addr,
}
#endif

+/*
+ * Do early cow for the page and the pte. Return true if the page
+ * was duplicated successfully, false otherwise.
+ */
+static bool duplicate_page(struct mm_struct *mm, struct vm_area_struct *vma,
+ unsigned long address, struct page *page,
+ pte_t *newpte)
+{
+ struct page *new_page;
+ pte_t entry;
+
+ new_page = alloc_page_vma(GFP_HIGHUSER_MOVABLE, vma, address);
+ if (!new_page)
+ return false;
+
+ copy_user_highpage(new_page, page, address, vma);
+ if (mem_cgroup_charge(new_page, mm, GFP_KERNEL)) {
+ put_page(new_page);
+ return false;
+ }
+ cgroup_throttle_swaprate(new_page, GFP_KERNEL);
+ __SetPageUptodate(new_page);
+
+ entry = mk_pte(new_page, vma->vm_page_prot);
+ entry = pte_sw_mkyoung(entry);
+ entry = maybe_mkwrite(pte_mkdirty(entry), vma);
+ *newpte = entry;
+ page_add_new_anon_rmap(new_page, vma, address, false);
+
+ return true;
+}
+
/*
* copy one vm_area from one task to the other. Assumes the page tables
* already present in the new task to be cleared in the whole range
@@ -698,11 +730,13 @@ struct page *vm_normal_page_pmd(struct vm_area_struct *vma, unsigned long addr,
static inline unsigned long
copy_one_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pte_t *dst_pte, pte_t *src_pte, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, int *rss)
{
unsigned long vm_flags = vma->vm_flags;
pte_t pte = *src_pte;
struct page *page;
+ bool wp;

/* pte contains position in swap or file, so copy. */
if (unlikely(!pte_present(pte))) {
@@ -779,10 +813,7 @@ copy_one_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,
* If it's a COW mapping, write protect it both
* in the parent and the child
*/
- if (is_cow_mapping(vm_flags) && pte_write(pte)) {
- ptep_set_wrprotect(src_mm, addr, src_pte);
- pte = pte_wrprotect(pte);
- }
+ wp = is_cow_mapping(vm_flags) && pte_write(pte);

/*
* If it's a shared mapping, mark it clean in
@@ -805,6 +836,36 @@ copy_one_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,
get_page(page);
page_dup_rmap(page, false);
rss[mm_counter(page)]++;
+
+ /*
+ * If the page is pinned in source mm, do early cow right now
+ * so that the pinned page won't be replaced by another random
+ * page without being noticed after the fork().
+ *
+ * Note: there can be some very rare cases where we end up doing
+ * an unnecessary cow here, because page_maybe_dma_pinned() can be
+ * bogus at times and the has_pinned flag is currently aggressive
+ * too. However this should be good enough for us for now, as long
+ * as we cover all the pinned pages.
+ *
+ * If the page copy succeeded, undo the wp. Otherwise we probably
+ * don't have enough memory; we could either report that up so the
+ * fork() fails, or, for simplicity, keep the old behavior of
+ * leaving the page shared.
+ *
+ * Note: we pass in src_mm here for accounting because
+ * copy_namespaces() has not been done yet when copy_mm()
+ * happens.
+ */
+ if (wp && READ_ONCE(src_mm->has_pinned) &&
+ page_maybe_dma_pinned(page) &&
+ duplicate_page(src_mm, new, addr, page, &pte))
+ wp = false;
+ }
+
+ if (wp) {
+ ptep_set_wrprotect(src_mm, addr, src_pte);
+ pte = pte_wrprotect(pte);
}

out_set_pte:
@@ -814,6 +875,7 @@ copy_one_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,

static int copy_pte_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pmd_t *dst_pmd, pmd_t *src_pmd, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pte_t *orig_src_pte, *orig_dst_pte;
@@ -852,7 +914,7 @@ static int copy_pte_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
continue;
}
entry.val = copy_one_pte(dst_mm, src_mm, dst_pte, src_pte,
- vma, addr, rss);
+ vma, new, addr, rss);
if (entry.val)
break;
progress += 8;
@@ -877,6 +939,7 @@ static int copy_pte_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,

static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pud_t *dst_pud, pud_t *src_pud, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pmd_t *src_pmd, *dst_pmd;
@@ -903,7 +966,7 @@ static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src
if (pmd_none_or_clear_bad(src_pmd))
continue;
if (copy_pte_range(dst_mm, src_mm, dst_pmd, src_pmd,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_pmd++, src_pmd++, addr = next, addr != end);
return 0;
@@ -911,6 +974,7 @@ static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src

static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
p4d_t *dst_p4d, p4d_t *src_p4d, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pud_t *src_pud, *dst_pud;
@@ -937,7 +1001,7 @@ static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src
if (pud_none_or_clear_bad(src_pud))
continue;
if (copy_pmd_range(dst_mm, src_mm, dst_pud, src_pud,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_pud++, src_pud++, addr = next, addr != end);
return 0;
@@ -945,6 +1009,7 @@ static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src

static inline int copy_p4d_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pgd_t *dst_pgd, pgd_t *src_pgd, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
p4d_t *src_p4d, *dst_p4d;
@@ -959,14 +1024,14 @@ static inline int copy_p4d_range(struct mm_struct *dst_mm, struct mm_struct *src
if (p4d_none_or_clear_bad(src_p4d))
continue;
if (copy_pud_range(dst_mm, src_mm, dst_p4d, src_p4d,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_p4d++, src_p4d++, addr = next, addr != end);
return 0;
}

int copy_page_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
- struct vm_area_struct *vma)
+ struct vm_area_struct *vma, struct vm_area_struct *new)
{
pgd_t *src_pgd, *dst_pgd;
unsigned long next;
@@ -1021,7 +1086,7 @@ int copy_page_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
if (pgd_none_or_clear_bad(src_pgd))
continue;
if (unlikely(copy_p4d_range(dst_mm, src_mm, dst_pgd, src_pgd,
- vma, addr, next))) {
+ vma, new, addr, next))) {
ret = -ENOMEM;
break;
}
--
2.26.2

