[PATCH v3 1/2] vhost: Add rbtree vdpa_mem_tree to saved the counted mem

From: Cindy Lu
Date: Sun Jun 26 2022 - 05:04:28 EST


We count pinned_vm as follow in vhost-vDPA

lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
ret = -ENOMEM;
goto unlock;
}
This means if we have two vDPA devices for the same VM the pages
would be counted twice. So we add a tree to save the page that
counted and we will not count it again.

Add vdpa_mem_tree to saved the mem that already counted.
use a hlist to saved the root for vdpa_mem_tree.

Signed-off-by: Cindy Lu <lulu@xxxxxxxxxx>
---
drivers/vhost/vhost.c | 63 +++++++++++++++++++++++++++++++++++++++++++
drivers/vhost/vhost.h | 1 +
2 files changed, 64 insertions(+)

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 40097826cff0..4ca8b1ed944b 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -32,6 +32,8 @@
#include <linux/kcov.h>

#include "vhost.h"
+#include <linux/hashtable.h>
+#include <linux/jhash.h>

static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444);
@@ -49,6 +51,14 @@ enum {
#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])

+struct vhost_vdpa_rbtree_node {
+ struct hlist_node node;
+ struct rb_root_cached vdpa_mem_tree;
+ struct mm_struct *mm_using;
+};
+static DECLARE_HASHTABLE(vhost_vdpa_rbtree_hlist, 8);
+int vhost_vdpa_rbtree_hlist_status;
+
#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
@@ -571,6 +581,51 @@ static void vhost_detach_mm(struct vhost_dev *dev)
dev->mm = NULL;
}

+struct rb_root_cached *vhost_vdpa_get_mem_tree(struct mm_struct *mm)
+{
+ struct vhost_vdpa_rbtree_node *rbtree_root = NULL;
+ struct rb_root_cached *vdpa_tree;
+ u32 key;
+
+ /* No hased table, init one */
+ if (vhost_vdpa_rbtree_hlist_status == 0) {
+ hash_init(vhost_vdpa_rbtree_hlist);
+ vhost_vdpa_rbtree_hlist_status = 1;
+ }
+
+ key = jhash_1word((u64)mm, JHASH_INITVAL);
+ hash_for_each_possible(vhost_vdpa_rbtree_hlist, rbtree_root, node,
+ key) {
+ if (rbtree_root->mm_using == mm)
+ return &(rbtree_root->vdpa_mem_tree);
+ }
+ rbtree_root = kmalloc(sizeof(*rbtree_root), GFP_KERNEL);
+ if (!rbtree_root)
+ return NULL;
+ rbtree_root->mm_using = mm;
+ rbtree_root->vdpa_mem_tree = RB_ROOT_CACHED;
+ hash_add(vhost_vdpa_rbtree_hlist, &rbtree_root->node, key);
+ vdpa_tree = &(rbtree_root->vdpa_mem_tree);
+ return vdpa_tree;
+}
+
+void vhost_vdpa_relase_mem_tree(struct mm_struct *mm)
+{
+ struct vhost_vdpa_rbtree_node *rbtree_root = NULL;
+ u32 key;
+
+ key = jhash_1word((u64)mm, JHASH_INITVAL);
+
+ /* No hased table, init one */
+ hash_for_each_possible(vhost_vdpa_rbtree_hlist, rbtree_root, node,
+ key) {
+ if (rbtree_root->mm_using == mm) {
+ hash_del(&rbtree_root->node);
+ kfree(rbtree_root);
+ }
+ }
+}
+
/* Caller should have device mutex */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
@@ -605,6 +660,11 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
err = vhost_dev_alloc_iovecs(dev);
if (err)
goto err_cgroup;
+ dev->vdpa_mem_tree = vhost_vdpa_get_mem_tree(dev->mm);
+ if (dev->vdpa_mem_tree == NULL) {
+ err = -ENOMEM;
+ goto err_cgroup;
+ }

return 0;
err_cgroup:
@@ -613,6 +673,7 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
dev->worker = NULL;
}
err_worker:
+ vhost_vdpa_relase_mem_tree(dev->mm);
vhost_detach_mm(dev);
dev->kcov_handle = 0;
err_mm:
@@ -710,6 +771,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
dev->worker = NULL;
dev->kcov_handle = 0;
}
+
+ vhost_vdpa_relase_mem_tree(dev->mm);
vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index d9109107af08..84de33de3abf 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -160,6 +160,7 @@ struct vhost_dev {
int byte_weight;
u64 kcov_handle;
bool use_worker;
+ struct rb_root_cached *vdpa_mem_tree;
int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg);
};
--
2.34.3