[PATCH] mm/mprotect: add a mkwrite parameter to change_protection()

From: Jérôme Glisse
Date: Thu Sep 13 2018 - 10:16:30 EST


The mkwrite parameter allows change_protection() to turn read-only page
table entries into writable ones. userfaultfd needs this to
un-write-protect a range after a write-protect fault has been handled.

Signed-off-by: Jérôme Glisse <jglisse@xxxxxxxxxx>
---
include/linux/huge_mm.h | 2 +-
include/linux/mm.h | 3 ++-
mm/huge_memory.c | 5 ++++-
mm/mempolicy.c | 2 +-
mm/mprotect.c | 37 +++++++++++++++++++++----------------
mm/userfaultfd.c | 2 +-
6 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
index a8a126259bc4..b51ff7f8e65c 100644
--- a/include/linux/huge_mm.h
+++ b/include/linux/huge_mm.h
@@ -45,7 +45,7 @@ extern bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
pmd_t *old_pmd, pmd_t *new_pmd, bool *need_flush);
extern int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
unsigned long addr, pgprot_t newprot,
- int prot_numa);
+ int prot_numa, bool mkwrite);
int vmf_insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
pmd_t *pmd, pfn_t pfn, bool write);
int vmf_insert_pfn_pud(struct vm_area_struct *vma, unsigned long addr,
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 5d5c7fd07dc0..2bbf3e33bf9e 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1492,7 +1492,8 @@ extern unsigned long move_page_tables(struct vm_area_struct *vma,
bool need_rmap_locks);
extern unsigned long change_protection(struct vm_area_struct *vma, unsigned long start,
unsigned long end, pgprot_t newprot,
- int dirty_accountable, int prot_numa);
+ int dirty_accountable, int prot_numa,
+ bool mkwrite);
extern int mprotect_fixup(struct vm_area_struct *vma,
struct vm_area_struct **pprev, unsigned long start,
unsigned long end, unsigned long newflags);
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index abf621aba672..49853f0b1570 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1842,7 +1842,8 @@ bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
* - HPAGE_PMD_NR is protections changed and TLB flush necessary
*/
int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
- unsigned long addr, pgprot_t newprot, int prot_numa)
+ unsigned long addr, pgprot_t newprot, int prot_numa,
+ bool mkwrite)
{
struct mm_struct *mm = vma->vm_mm;
spinlock_t *ptl;
@@ -1925,9 +1926,12 @@ int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
entry = pmd_modify(entry, newprot);
if (preserve_write)
entry = pmd_mk_savedwrite(entry);
+ if (mkwrite)
+ entry = pmd_mkwrite(entry);
ret = HPAGE_PMD_NR;
set_pmd_at(mm, addr, pmd, entry);
- BUG_ON(vma_is_anonymous(vma) && !preserve_write && pmd_write(entry));
+ BUG_ON(vma_is_anonymous(vma) && !preserve_write && !mkwrite &&
+ pmd_write(entry));
unlock:
spin_unlock(ptl);
return ret;
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 4ce44d3ff03d..2d0ee09e6b26 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -579,7 +579,7 @@ unsigned long change_prot_numa(struct vm_area_struct *vma,
{
int nr_updated;

- nr_updated = change_protection(vma, addr, end, PAGE_NONE, 0, 1);
+ nr_updated = change_protection(vma, addr, end, PAGE_NONE, 0, 1, false);
if (nr_updated)
count_vm_numa_events(NUMA_PTE_UPDATES, nr_updated);

diff --git a/mm/mprotect.c b/mm/mprotect.c
index 58b629bb70de..792669cfb7e1 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -36,7 +36,7 @@

static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
unsigned long addr, unsigned long end, pgprot_t newprot,
- int dirty_accountable, int prot_numa)
+ int dirty_accountable, int prot_numa, bool mkwrite)
{
struct mm_struct *mm = vma->vm_mm;
pte_t *pte, oldpte;
@@ -102,9 +102,9 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
ptent = pte_mk_savedwrite(ptent);

/* Avoid taking write faults for known dirty pages */
- if (dirty_accountable && pte_dirty(ptent) &&
- (pte_soft_dirty(ptent) ||
- !(vma->vm_flags & VM_SOFTDIRTY))) {
+ if (mkwrite || (dirty_accountable &&
+ pte_dirty(ptent) && (pte_soft_dirty(ptent) ||
+ !(vma->vm_flags & VM_SOFTDIRTY)))) {
ptent = pte_mkwrite(ptent);
}
ptep_modify_prot_commit(mm, addr, pte, ptent);
@@ -150,7 +150,8 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,

static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
pud_t *pud, unsigned long addr, unsigned long end,
- pgprot_t newprot, int dirty_accountable, int prot_numa)
+ pgprot_t newprot, int dirty_accountable, int prot_numa,
+ bool mkwrite)
{
pmd_t *pmd;
struct mm_struct *mm = vma->vm_mm;
@@ -179,7 +180,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
__split_huge_pmd(vma, pmd, addr, false, NULL);
} else {
int nr_ptes = change_huge_pmd(vma, pmd, addr,
- newprot, prot_numa);
+ newprot, prot_numa, mkwrite);

if (nr_ptes) {
if (nr_ptes == HPAGE_PMD_NR) {
@@ -194,7 +195,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
/* fall through, the trans huge pmd just split */
}
this_pages = change_pte_range(vma, pmd, addr, next, newprot,
- dirty_accountable, prot_numa);
+ dirty_accountable, prot_numa, mkwrite);
pages += this_pages;
next:
cond_resched();
@@ -210,7 +211,8 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,

static inline unsigned long change_pud_range(struct vm_area_struct *vma,
p4d_t *p4d, unsigned long addr, unsigned long end,
- pgprot_t newprot, int dirty_accountable, int prot_numa)
+ pgprot_t newprot, int dirty_accountable, int prot_numa,
+ bool mkwrite)
{
pud_t *pud;
unsigned long next;
@@ -222,7 +224,7 @@ static inline unsigned long change_pud_range(struct vm_area_struct *vma,
if (pud_none_or_clear_bad(pud))
continue;
pages += change_pmd_range(vma, pud, addr, next, newprot,
- dirty_accountable, prot_numa);
+ dirty_accountable, prot_numa, mkwrite);
} while (pud++, addr = next, addr != end);

return pages;
@@ -230,7 +232,8 @@ static inline unsigned long change_pud_range(struct vm_area_struct *vma,

static inline unsigned long change_p4d_range(struct vm_area_struct *vma,
pgd_t *pgd, unsigned long addr, unsigned long end,
- pgprot_t newprot, int dirty_accountable, int prot_numa)
+ pgprot_t newprot, int dirty_accountable, int prot_numa,
+ bool mkwrite)
{
p4d_t *p4d;
unsigned long next;
@@ -242,7 +245,7 @@ static inline unsigned long change_p4d_range(struct vm_area_struct *vma,
if (p4d_none_or_clear_bad(p4d))
continue;
pages += change_pud_range(vma, p4d, addr, next, newprot,
- dirty_accountable, prot_numa);
+ dirty_accountable, prot_numa, mkwrite);
} while (p4d++, addr = next, addr != end);

return pages;
@@ -250,7 +253,7 @@ static inline unsigned long change_p4d_range(struct vm_area_struct *vma,

static unsigned long change_protection_range(struct vm_area_struct *vma,
unsigned long addr, unsigned long end, pgprot_t newprot,
- int dirty_accountable, int prot_numa)
+ int dirty_accountable, int prot_numa, bool mkwrite)
{
struct mm_struct *mm = vma->vm_mm;
pgd_t *pgd;
@@ -267,7 +270,7 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
if (pgd_none_or_clear_bad(pgd))
continue;
pages += change_p4d_range(vma, pgd, addr, next, newprot,
- dirty_accountable, prot_numa);
+ dirty_accountable, prot_numa, mkwrite);
} while (pgd++, addr = next, addr != end);

/* Only flush the TLB if we actually modified any entries: */
@@ -280,14 +283,16 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,

unsigned long change_protection(struct vm_area_struct *vma, unsigned long start,
unsigned long end, pgprot_t newprot,
- int dirty_accountable, int prot_numa)
+ int dirty_accountable, int prot_numa, bool mkwrite)
{
unsigned long pages;

if (is_vm_hugetlb_page(vma))
pages = hugetlb_change_protection(vma, start, end, newprot);
else
- pages = change_protection_range(vma, start, end, newprot, dirty_accountable, prot_numa);
+ pages = change_protection_range(vma, start, end, newprot,
+ dirty_accountable,
+ prot_numa, mkwrite);

return pages;
}
@@ -366,7 +371,7 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
vma_set_page_prot(vma);

change_protection(vma, start, end, vma->vm_page_prot,
- dirty_accountable, 0);
+ dirty_accountable, 0, false);

/*
* Private VM_LOCKED VMA becoming writable: trigger COW to avoid major
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index a0379c5ffa7c..c745c5d87523 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -632,7 +632,7 @@ int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start,
newprot = vm_get_page_prot(dst_vma->vm_flags);

change_protection(dst_vma, start, start + len, newprot,
- !enable_wp, 0);
+ 0, 0, !enable_wp);

err = 0;
out_unlock:
--
2.17.1