[PATCH v3] mm/gup: Allow real explicit breaking of COW
From: Peter Xu
Date: Tue Aug 11 2020 - 14:40:04 EST
Starting from commit 17839856fd58 ("gup: document and work around "COW can
break either way" issue", 2020-06-02), explicit copy-on-write behavior is
enforced for private gup pages even if it's a read-only. It is achieved by
always passing FOLL_WRITE to emulate a write.
That should fix the COW issue that we were facing, however above commit could
also break userfaultfd-wp and applications like umapsort [1,2].
One general routine of umap-like program is: userspace library will manage page
allocations, and it will evict the least recently used pages from memory to
external storages (e.g., file systems). Below are the general steps to evict
an in-memory page in the uffd service thread when the page pool is full:
(1) UFFDIO_WRITEPROTECT with mode=WP on some to-be-evicted page P, so that
further writes to page P will block (keep page P clean)
(2) Copy page P to external storage (e.g. file system)
(3) MADV_DONTNEED to evict page P
Here step (1) makes sure that the page to dump will always be up-to-date, so
that the page snapshot in the file system is consistent with the one that was
in the memory. However with commit 17839856fd58, step (2) can potentially hang
itself because e.g. if we use write() to a file system fd to dump the page
data, that will be a translated read gup request in the file system driver to
read the page content, then the read gup will be translated to a write gup due
to the new enforced COW behavior. This write gup will further trigger
handle_userfault() and hang the uffd service thread itself.
I think the problem will go away too if we replace the write() to the file
system into a memory write to a mmaped region in the userspace library, because
normal page faults will not enforce COW, only gup is affected. However we
cannot forbid users to use write() or any form of kernel level read gup.
One solution is actually already mentioned in commit 17839856fd58, which is to
provide an explicit BREAK_COW scemantics for enforced COW. Then we can still
use FAULT_FLAG_WRITE to identify whether this is a "real write request" or an
"enfornced COW (read) request".
With the enforced COW, we also need to inherit UFFD_WP bit during COW because
now COW can happen with UFFD_WP enabled (previously, it cannot).
Since at it, rename the variable in __handle_mm_fault() from "dirty" to "cow"
to better suite its functionality.
[1] https://github.com/LLNL/umap-apps/blob/develop/src/umapsort/umapsort.cpp
[2] https://github.com/LLNL/umap
CC: Marty Mcfadden <mcfadden8@xxxxxxxx>
CC: Maya B. Gokhale <gokhale2@xxxxxxxx>
CC: Andrea Arcangeli <aarcange@xxxxxxxxxx>
CC: Linus Torvalds <torvalds@xxxxxxxxxxxxxxxxxxxx>
CC: Andrew Morton <akpm@xxxxxxxxxxxxxxxxxxxx>
CC: Jann Horn <jannh@xxxxxxxxxx>
CC: Christoph Hellwig <hch@xxxxxx>
CC: Oleg Nesterov <oleg@xxxxxxxxxx>
CC: Kirill Shutemov <kirill@xxxxxxxxxxxxx>
CC: Jan Kara <jack@xxxxxxx>
Fixes: 17839856fd588f4ab6b789f482ed3ffd7c403e1f
Signed-off-by: Peter Xu <peterx@xxxxxxxxxx>
---
v3:
- inherit UFFD_WP bit for COW too
- take care of huge page cases
- more comments
v2:
- apply FAULT_FLAG_BREAK_COW correctly when FOLL_BREAK_COW [Christoph]
- removed comments above do_wp_page which seems redundant
---
include/linux/mm.h | 3 +++
mm/gup.c | 6 ++++--
mm/huge_memory.c | 12 +++++++++++-
mm/memory.c | 39 +++++++++++++++++++++++++++++++--------
4 files changed, 49 insertions(+), 11 deletions(-)
diff --git a/include/linux/mm.h b/include/linux/mm.h
index f6a82f9bccd7..a1f5c92b44cb 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -409,6 +409,7 @@ extern pgprot_t protection_map[16];
* @FAULT_FLAG_REMOTE: The fault is not for current task/mm.
* @FAULT_FLAG_INSTRUCTION: The fault was during an instruction fetch.
* @FAULT_FLAG_INTERRUPTIBLE: The fault can be interrupted by non-fatal signals.
+ * @FAULT_FLAG_BREAK_COW: Do COW explicitly for the fault (even for read).
*
* About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
* whether we would allow page faults to retry by specifying these two
@@ -439,6 +440,7 @@ extern pgprot_t protection_map[16];
#define FAULT_FLAG_REMOTE 0x80
#define FAULT_FLAG_INSTRUCTION 0x100
#define FAULT_FLAG_INTERRUPTIBLE 0x200
+#define FAULT_FLAG_BREAK_COW 0x400
/*
* The default fault flags that should be used by most of the
@@ -2756,6 +2758,7 @@ struct page *follow_page(struct vm_area_struct *vma, unsigned long address,
#define FOLL_SPLIT_PMD 0x20000 /* split huge pmd before returning */
#define FOLL_PIN 0x40000 /* pages must be released via unpin_user_page */
#define FOLL_FAST_ONLY 0x80000 /* gup_fast: prevent fall-back to slow gup */
+#define FOLL_BREAK_COW 0x100000 /* request for explicit COW (even for read) */
/*
* FOLL_PIN and FOLL_LONGTERM may be used in various combinations with each
diff --git a/mm/gup.c b/mm/gup.c
index d8a33dd1430d..c33e84ab9c36 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -870,6 +870,8 @@ static int faultin_page(struct task_struct *tsk, struct vm_area_struct *vma,
return -ENOENT;
if (*flags & FOLL_WRITE)
fault_flags |= FAULT_FLAG_WRITE;
+ if (*flags & FOLL_BREAK_COW)
+ fault_flags |= FAULT_FLAG_BREAK_COW;
if (*flags & FOLL_REMOTE)
fault_flags |= FAULT_FLAG_REMOTE;
if (locked)
@@ -1076,7 +1078,7 @@ static long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
}
if (is_vm_hugetlb_page(vma)) {
if (should_force_cow_break(vma, foll_flags))
- foll_flags |= FOLL_WRITE;
+ foll_flags |= FOLL_BREAK_COW;
i = follow_hugetlb_page(mm, vma, pages, vmas,
&start, &nr_pages, i,
foll_flags, locked);
@@ -1095,7 +1097,7 @@ static long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
}
if (should_force_cow_break(vma, foll_flags))
- foll_flags |= FOLL_WRITE;
+ foll_flags |= FOLL_BREAK_COW;
retry:
/*
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 206f52b36ffb..c88f773d03af 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1296,7 +1296,17 @@ vm_fault_t do_huge_pmd_wp_page(struct vm_fault *vmf, pmd_t orig_pmd)
if (reuse_swap_page(page, NULL)) {
pmd_t entry;
entry = pmd_mkyoung(orig_pmd);
- entry = maybe_pmd_mkwrite(pmd_mkdirty(entry), vma);
+ entry = pmd_mkdirty(entry);
+ if (pmd_uffd_wp(orig_pmd))
+ /*
+ * This can happen when an uffd-wp protected page is
+ * copied due to enfornced COW. When it happens, we
+ * need to keep the uffd-wp bit even after COW, and
+ * make sure write bit is kept cleared.
+ */
+ entry = pmd_mkuffd_wp(pmd_wrprotect(entry));
+ else
+ entry = maybe_pmd_mkwrite(entry, vma);
if (pmdp_set_access_flags(vma, haddr, vmf->pmd, entry, 1))
update_mmu_cache_pmd(vma, vmf->address, vmf->pmd);
unlock_page(page);
diff --git a/mm/memory.c b/mm/memory.c
index c39a13b09602..b27b555a9df8 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2706,7 +2706,17 @@ static vm_fault_t wp_page_copy(struct vm_fault *vmf)
flush_cache_page(vma, vmf->address, pte_pfn(vmf->orig_pte));
entry = mk_pte(new_page, vma->vm_page_prot);
entry = pte_sw_mkyoung(entry);
- entry = maybe_mkwrite(pte_mkdirty(entry), vma);
+ entry = pte_mkdirty(entry);
+ if (pte_uffd_wp(vmf->orig_pte))
+ /*
+ * This can happen when an uffd-wp protected page is
+ * copied due to enfornced COW. When it happens, we
+ * need to keep the uffd-wp bit even after COW, and
+ * make sure write bit is kept cleared.
+ */
+ entry = pte_mkuffd_wp(pte_wrprotect(entry));
+ else
+ entry = maybe_mkwrite(entry, vma);
/*
* Clear the pte entry and flush it first, before updating the
* pte with the new entry. This will avoid a race condition
@@ -2900,7 +2910,13 @@ static vm_fault_t do_wp_page(struct vm_fault *vmf)
{
struct vm_area_struct *vma = vmf->vma;
- if (userfaultfd_pte_wp(vma, *vmf->pte)) {
+ /*
+ * Userfaultfd-wp only cares about real writes. E.g., enforced COW for
+ * read does not count. When that happens, we will do the COW with the
+ * UFFD_WP bit inherited from the original PTE/PMD.
+ */
+ if ((vmf->flags & FAULT_FLAG_WRITE) &&
+ userfaultfd_pte_wp(vma, *vmf->pte)) {
pte_unmap_unlock(vmf->pte, vmf->ptl);
return handle_userfault(vmf, VM_UFFD_WP);
}
@@ -3290,7 +3306,7 @@ vm_fault_t do_swap_page(struct vm_fault *vmf)
put_page(swapcache);
}
- if (vmf->flags & FAULT_FLAG_WRITE) {
+ if (vmf->flags & (FAULT_FLAG_WRITE | FAULT_FLAG_BREAK_COW)) {
ret |= do_wp_page(vmf);
if (ret & VM_FAULT_ERROR)
ret &= VM_FAULT_ERROR;
@@ -4117,7 +4133,14 @@ static inline vm_fault_t create_huge_pmd(struct vm_fault *vmf)
static inline vm_fault_t wp_huge_pmd(struct vm_fault *vmf, pmd_t orig_pmd)
{
if (vma_is_anonymous(vmf->vma)) {
- if (userfaultfd_huge_pmd_wp(vmf->vma, orig_pmd))
+ /*
+ * Userfaultfd-wp only cares about real writes. E.g., enforced
+ * COW for read does not count. When that happens, we will do
+ * the COW with the UFFD_WP bit inherited from the original
+ * PTE/PMD.
+ */
+ if ((vmf->flags & FAULT_FLAG_WRITE) &&
+ userfaultfd_huge_pmd_wp(vmf->vma, orig_pmd))
return handle_userfault(vmf, VM_UFFD_WP);
return do_huge_pmd_wp_page(vmf, orig_pmd);
}
@@ -4241,7 +4264,7 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
update_mmu_tlb(vmf->vma, vmf->address, vmf->pte);
goto unlock;
}
- if (vmf->flags & FAULT_FLAG_WRITE) {
+ if (vmf->flags & (FAULT_FLAG_WRITE | FAULT_FLAG_BREAK_COW)) {
if (!pte_write(entry))
return do_wp_page(vmf);
entry = pte_mkdirty(entry);
@@ -4281,7 +4304,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
.pgoff = linear_page_index(vma, address),
.gfp_mask = __get_fault_gfp_mask(vma),
};
- unsigned int dirty = flags & FAULT_FLAG_WRITE;
+ bool cow = flags & (FAULT_FLAG_WRITE | FAULT_FLAG_BREAK_COW);
struct mm_struct *mm = vma->vm_mm;
pgd_t *pgd;
p4d_t *p4d;
@@ -4308,7 +4331,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
/* NUMA case for anonymous PUDs would go here */
- if (dirty && !pud_write(orig_pud)) {
+ if (cow && !pud_write(orig_pud)) {
ret = wp_huge_pud(&vmf, orig_pud);
if (!(ret & VM_FAULT_FALLBACK))
return ret;
@@ -4346,7 +4369,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
if (pmd_protnone(orig_pmd) && vma_is_accessible(vma))
return do_huge_pmd_numa_page(&vmf, orig_pmd);
- if (dirty && !pmd_write(orig_pmd)) {
+ if (cow && !pmd_write(orig_pmd)) {
ret = wp_huge_pmd(&vmf, orig_pmd);
if (!(ret & VM_FAULT_FALLBACK))
return ret;
--
2.26.2