[PATCH 2/5] mm/fork: Pass new vma pointer into copy_page_range()

From: Peter Xu
Date: Mon Sep 21 2020 - 17:18:17 EST


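Pass the new vma pointer into copy_page_range() as an extra argument, and
plumb it down through copy_p4d_range(), copy_pud_range(), copy_pmd_range()
and copy_pte_range(). The new argument is not used yet.

This prepares for future work to trigger early COW (copy-on-write) on pinned
pages during fork(). No functional change intended.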

Signed-off-by: Peter Xu <peterx@xxxxxxxxxx>
---
include/linux/mm.h | 2 +-
kernel/fork.c | 2 +-
mm/memory.c | 14 +++++++++-----
3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/include/linux/mm.h b/include/linux/mm.h
index ca6e6a81576b..bf1ac54be55e 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1644,7 +1644,7 @@ struct mmu_notifier_range;
void free_pgd_range(struct mmu_gather *tlb, unsigned long addr,
unsigned long end, unsigned long floor, unsigned long ceiling);
int copy_page_range(struct mm_struct *dst, struct mm_struct *src,
- struct vm_area_struct *vma);
+ struct vm_area_struct *vma, struct vm_area_struct *new);
int follow_pte_pmd(struct mm_struct *mm, unsigned long address,
struct mmu_notifier_range *range,
pte_t **ptepp, pmd_t **pmdpp, spinlock_t **ptlp);
diff --git a/kernel/fork.c b/kernel/fork.c
index 7237d418e7b5..843807ade6dd 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -589,7 +589,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,

mm->map_count++;
if (!(tmp->vm_flags & VM_WIPEONFORK))
- retval = copy_page_range(mm, oldmm, mpnt);
+ retval = copy_page_range(mm, oldmm, mpnt, tmp);

if (tmp->vm_ops && tmp->vm_ops->open)
tmp->vm_ops->open(tmp);
diff --git a/mm/memory.c b/mm/memory.c
index 469af373ae76..7525147908c4 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -814,6 +814,7 @@ copy_one_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,

static int copy_pte_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pmd_t *dst_pmd, pmd_t *src_pmd, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pte_t *orig_src_pte, *orig_dst_pte;
@@ -877,6 +878,7 @@ static int copy_pte_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,

static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pud_t *dst_pud, pud_t *src_pud, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pmd_t *src_pmd, *dst_pmd;
@@ -903,7 +905,7 @@ static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src
if (pmd_none_or_clear_bad(src_pmd))
continue;
if (copy_pte_range(dst_mm, src_mm, dst_pmd, src_pmd,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_pmd++, src_pmd++, addr = next, addr != end);
return 0;
@@ -911,6 +913,7 @@ static inline int copy_pmd_range(struct mm_struct *dst_mm, struct mm_struct *src

static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
p4d_t *dst_p4d, p4d_t *src_p4d, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
pud_t *src_pud, *dst_pud;
@@ -937,7 +940,7 @@ static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src
if (pud_none_or_clear_bad(src_pud))
continue;
if (copy_pmd_range(dst_mm, src_mm, dst_pud, src_pud,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_pud++, src_pud++, addr = next, addr != end);
return 0;
@@ -945,6 +948,7 @@ static inline int copy_pud_range(struct mm_struct *dst_mm, struct mm_struct *src

static inline int copy_p4d_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
pgd_t *dst_pgd, pgd_t *src_pgd, struct vm_area_struct *vma,
+ struct vm_area_struct *new,
unsigned long addr, unsigned long end)
{
p4d_t *src_p4d, *dst_p4d;
@@ -959,14 +963,14 @@ static inline int copy_p4d_range(struct mm_struct *dst_mm, struct mm_struct *src
if (p4d_none_or_clear_bad(src_p4d))
continue;
if (copy_pud_range(dst_mm, src_mm, dst_p4d, src_p4d,
- vma, addr, next))
+ vma, new, addr, next))
return -ENOMEM;
} while (dst_p4d++, src_p4d++, addr = next, addr != end);
return 0;
}

int copy_page_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
- struct vm_area_struct *vma)
+ struct vm_area_struct *vma, struct vm_area_struct *new)
{
pgd_t *src_pgd, *dst_pgd;
unsigned long next;
@@ -1021,7 +1025,7 @@ int copy_page_range(struct mm_struct *dst_mm, struct mm_struct *src_mm,
if (pgd_none_or_clear_bad(src_pgd))
continue;
if (unlikely(copy_p4d_range(dst_mm, src_mm, dst_pgd, src_pgd,
- vma, addr, next))) {
+ vma, new, addr, next))) {
ret = -ENOMEM;
break;
}
--
2.26.2