[RFC V2 35/37] mm, dmem: introduce dregion->memmap for dmem

From: yulei . kernel
Date: Mon Dec 07 2020 - 06:36:06 EST


From: Yulei Zhang <yuleixzhang@xxxxxxxxxxx>

Add a 'memmap' array to struct dmem_region that maps each dmem page to a
struct dmempage.

Currently struct dmempage has a single member, '_refcount', which counts
the modules currently holding a reference to the dmem page.

A module that allocates a dmem page from the dmem pool takes the first
reference, which sets _refcount to 1.

A module that frees a dmem page back to the dmem pool decrements
_refcount; the page is returned to the pool only when _refcount drops to zero.

Whenever module A passes a dmem page to module B, module B must call
get_dmem_pfn() to increment the page's _refcount before using it, while
module A still holds its own reference. This prevents the page from being
freed by another module in parallel. Conversely, once module B has finished
using the page, it must call put_dmem_pfn() to drop its reference.

Signed-off-by: Chen Zhuo <sagazchen@xxxxxxxxxxx>
Signed-off-by: Yulei Zhang <yuleixzhang@xxxxxxxxxxx>
---
include/linux/dmem.h | 5 ++
mm/dmem.c | 147 ++++++++++++++++++++++++++++++++++++++++++++++-----
2 files changed, 139 insertions(+), 13 deletions(-)

diff --git a/include/linux/dmem.h b/include/linux/dmem.h
index fe0b270..8aaa80b 100644
--- a/include/linux/dmem.h
+++ b/include/linux/dmem.h
@@ -22,6 +22,9 @@
bool is_dmem_pfn(unsigned long pfn);
#define dmem_free_page(addr) dmem_free_pages(addr, 1)

+void get_dmem_pfn(unsigned long pfn); /* caller must already hold a ref */
+#define put_dmem_pfn(pfn) dmem_free_page(PFN_PHYS(pfn)) /* drop one ref */
+
bool dmem_memory_failure(unsigned long pfn, int flags);

struct dmem_mce_notifier_info {
@@ -45,5 +48,7 @@ static inline bool dmem_memory_failure(unsigned long pfn, int flags)
{
return false;
}
+static inline void get_dmem_pfn(unsigned long pfn) {} /* stub: no-op */
+static inline void put_dmem_pfn(unsigned long pfn) {} /* stub: no-op */
#endif
#endif /* _LINUX_DMEM_H */
diff --git a/mm/dmem.c b/mm/dmem.c
index dd81b24..776dbf2 100644
--- a/mm/dmem.c
+++ b/mm/dmem.c
@@ -47,6 +47,7 @@ struct dmem_region {

unsigned long static_error_bitmap;
unsigned long *error_bitmap;
+ void *memmap;
};

/*
@@ -91,6 +92,10 @@ struct dmem_pool {
struct dmem_node nodes[MAX_NUMNODES];
};

+struct dmempage {
+ atomic_t _refcount; /* holders of this dpage; 0 while it is free */
+};
+
static struct dmem_pool dmem_pool = {
.lock = __MUTEX_INITIALIZER(dmem_pool.lock),
.mce_notifier_chain = RAW_NOTIFIER_INIT(dmem_pool.mce_notifier_chain),
@@ -123,6 +128,40 @@ struct dmem_pool {
#define for_each_dmem_region(_dnode, _dregion) \
list_for_each_entry(_dregion, &(_dnode)->regions, node)

+#define pfn_to_dmempage(_pfn, _dregion) /* memmap entry for @_pfn */ \
+ ((struct dmempage *)(_dregion)->memmap + \
+ (pfn_to_dpage(_pfn) - (_dregion)->dpage_start_pfn)) /* index first */
+
+#define dmempage_to_dpage(_dmempage, _dregion) /* dpage of an entry */ \
+ ((_dmempage) - (struct dmempage *)(_dregion)->memmap + \
+ (_dregion)->dpage_start_pfn)
+
+static inline int dmempage_count(struct dmempage *dp)
+{
+ return atomic_read(&dp->_refcount); /* current number of holders */
+}
+
+static inline void set_dmempage_count(struct dmempage *dp, int count)
+{
+ atomic_set(&dp->_refcount, count);
+}
+
+static inline void dmempage_ref_inc(struct dmempage *dp)
+{
+ atomic_add(1, &dp->_refcount);
+}
+
+static inline int dmempage_ref_dec_and_test(struct dmempage *dp)
+{
+ return atomic_sub_and_test(1, &dp->_refcount); /* true once it hits 0 */
+}
+
+static inline int put_dmempage_testzero(struct dmempage *dp)
+{
+ VM_BUG_ON(!dmempage_count(dp)); /* releasing an unreferenced dpage */
+ return dmempage_ref_dec_and_test(dp);
+}
+
int dmem_register_mce_notifier(struct notifier_block *nb)
{
int ret;
@@ -559,10 +598,25 @@ static int __init dmem_late_init(void)
}
late_initcall(dmem_late_init);

+/* One zeroed struct dmempage per dpage, so every _refcount starts at 0. */
+static void *dmem_memmap_alloc(unsigned long dpages)
+{
+ if (dpages > ULONG_MAX / sizeof(struct dmempage))
+ return NULL; /* the size computation below would overflow */
+ return vzalloc(dpages * sizeof(struct dmempage));
+}
+
+static void dmem_memmap_free(void *memmap)
+{
+ vfree(memmap); /* vfree(NULL) is a no-op */
+}
+
static int dmem_alloc_region_init(struct dmem_region *dregion,
unsigned long *dpages)
{
unsigned long start, end, *bitmap;
+ void *memmap;

start = DMEM_PAGE_UP(dregion->reserved_start_addr);
end = DMEM_PAGE_DOWN(dregion->reserved_end_addr);
@@ -575,7 +629,14 @@ static int dmem_alloc_region_init(struct dmem_region *dregion,
if (!bitmap)
return -ENOMEM;

+ memmap = dmem_memmap_alloc(*dpages);
+ if (!memmap) {
+ dmem_bitmap_free(*dpages, bitmap, &dregion->static_bitmap);
+ return -ENOMEM;
+ }
+
dregion->bitmap = bitmap;
+ dregion->memmap = memmap;
dregion->next_free_pos = 0;
dregion->dpage_start_pfn = start;
dregion->dpage_end_pfn = end;
@@ -650,7 +711,9 @@ static void dmem_alloc_region_uinit(struct dmem_region *dregion)
dmem_uinit_check_alloc_bitmap(dregion);

dmem_bitmap_free(dpages, bitmap, &dregion->static_bitmap);
+ dmem_memmap_free(dregion->memmap);
dregion->bitmap = NULL;
+ dregion->memmap = NULL;
}

static void __dmem_alloc_uinit(void)
@@ -793,6 +856,16 @@ int dmem_alloc_init(unsigned long dpage_shift)
return dpage_to_phys(dregion->dpage_start_pfn + pos);
}

+static void prep_new_dmempage(unsigned long phys, unsigned int nr,
+ struct dmem_region *dregion)
+{
+ struct dmempage *dp = pfn_to_dmempage(PHYS_PFN(phys), dregion);
+
+ /* each newly allocated dpage starts out with a single reference */
+ while (nr--)
+ set_dmempage_count(dp++, 1);
+}
+
/*
* allocate dmem pages from the nodelist
*
@@ -839,6 +912,7 @@ int dmem_alloc_init(unsigned long dpage_shift)
if (addr) {
dnode_count_free_dpages(dnode,
-(long)(*result_nr));
+ prep_new_dmempage(addr, *result_nr, dregion);
break;
}
}
@@ -993,6 +1067,41 @@ static struct dmem_region *find_dmem_region(phys_addr_t phys_addr,
return NULL;
}

+static unsigned int free_dmempages_prepare(struct dmempage *dmempage,
+ unsigned int dpages_nr)
+{
+ unsigned int zeroed = 0;
+
+ /* drop one reference per dpage; count those that reached zero */
+ while (dpages_nr--)
+ zeroed += put_dmempage_testzero(dmempage++) ? 1 : 0;
+
+ return zeroed;
+}
+
+static void __dmem_free_pages(struct dmempage *dmempage,
+ unsigned int dpages_nr,
+ struct dmem_region *dregion,
+ struct dmem_node *pdnode)
+{
+ phys_addr_t dpage = dmempage_to_dpage(dmempage, dregion);
+ u64 pos = dpage - dregion->dpage_start_pfn;
+ unsigned long err_dpages;
+
+ /* the caller, dmem_free_pages(), holds dmem_pool.lock */
+ trace_dmem_free_pages(dpage_to_phys(dpage), dpages_nr);
+ WARN_ON(!dmem_pool.dpage_shift);
+
+ /* keep next_free_pos at or below the lowest freed position */
+ dregion->next_free_pos = min(dregion->next_free_pos, pos);
+
+ /* it is not possible to span multiple regions */
+ WARN_ON(dpage + dpages_nr - 1 >= dregion->dpage_end_pfn);
+
+ err_dpages = dmem_alloc_bitmap_clear(dregion, dpage, dpages_nr);
+ dnode_count_free_dpages(pdnode, dpages_nr - err_dpages);
+}
+
/*
* free dmem page to the dmem pool
* @addr: the physical addree will be freed
@@ -1002,27 +1111,26 @@ void dmem_free_pages(phys_addr_t addr, unsigned int dpages_nr)
{
struct dmem_region *dregion;
struct dmem_node *pdnode = NULL;
- phys_addr_t dpage = phys_to_dpage(addr);
- u64 pos;
- unsigned long err_dpages;
+ struct dmempage *dmempage;
+ unsigned int nr;

mutex_lock(&dmem_pool.lock);

- trace_dmem_free_pages(addr, dpages_nr);
- WARN_ON(!dmem_pool.dpage_shift);
-
dregion = find_dmem_region(addr, &pdnode);
WARN_ON(!dregion || !dregion->bitmap || !pdnode);

- pos = dpage - dregion->dpage_start_pfn;
- dregion->next_free_pos = min(dregion->next_free_pos, pos);
-
- /* it is not possible to span multiple regions */
- WARN_ON(dpage + dpages_nr - 1 >= dregion->dpage_end_pfn);
+ dmempage = pfn_to_dmempage(PHYS_PFN(addr), dregion);

- err_dpages = dmem_alloc_bitmap_clear(dregion, dpage, dpages_nr);
+ nr = free_dmempages_prepare(dmempage, dpages_nr);
+ if (nr == dpages_nr) {
+ __dmem_free_pages(dmempage, dpages_nr, dregion, pdnode);
+ } else if (nr) {
+ /* free only the dpages whose refcount is now zero */
+ for (; dpages_nr; dpages_nr--, dmempage++)
+ if (!dmempage_count(dmempage))
+ __dmem_free_pages(dmempage, 1, dregion, pdnode);
+ }

- dnode_count_free_dpages(pdnode, dpages_nr - err_dpages);
mutex_unlock(&dmem_pool.lock);
}
EXPORT_SYMBOL(dmem_free_pages);
@@ -1073,3 +1181,16 @@ bool is_dmem_pfn(unsigned long pfn)
return !!find_dmem_region(__pfn_to_phys(pfn), &dnode);
}
EXPORT_SYMBOL(is_dmem_pfn);
+
+void get_dmem_pfn(unsigned long pfn)
+{
+ struct dmem_node *dnode; /* find_dmem_region() out-param, unused */
+ struct dmem_region *dregion = find_dmem_region(PFN_PHYS(pfn), &dnode);
+ struct dmempage *dmempage;
+
+ VM_BUG_ON(!dregion || !dregion->memmap);
+ dmempage = pfn_to_dmempage(pfn, dregion);
+ VM_BUG_ON(dmempage_count(dmempage) + 127u <= 127u); /* 0 or near overflow */
+ dmempage_ref_inc(dmempage);
+}
+EXPORT_SYMBOL(get_dmem_pfn);
--
1.8.3.1