[PATCH 2/2] vsock: refactor vsock_for_each_connected_socket

From: Jiyong Park
Date: Thu Mar 10 2022 - 07:55:08 EST


Make vsock_for_each_connected_socket() iterate only over the sockets of
a specific transport, instead of requiring every caller to filter the
sockets manually, which is error-prone.

Signed-off-by: Jiyong Park <jiyong@xxxxxxxxxx>
---
drivers/vhost/vsock.c | 7 ++-----
include/net/af_vsock.h | 3 ++-
net/vmw_vsock/af_vsock.c | 9 +++++++--
net/vmw_vsock/virtio_transport.c | 12 ++++--------
net/vmw_vsock/vmci_transport.c | 8 ++------
5 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 853ddac00d5b..e6c9d41db1de 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -722,10 +722,6 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
* executing.
*/

- /* Only handle our own sockets */
- if (vsk->transport != &vhost_transport.transport)
- return;
-
/* If the peer is still valid, no need to reset connection */
if (vhost_vsock_get(vsk->remote_addr.svm_cid))
return;
@@ -757,7 +753,8 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)

/* Iterating over all connections for all CIDs to find orphans is
* inefficient. Room for improvement here. */
- vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
+ vsock_for_each_connected_socket(&vhost_transport.transport,
+ vhost_vsock_reset_orphans);

/* Don't check the owner, because we are in the release path, so we
* need to stop the vsock device in any case.
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index ab207677e0a8..f742e50207fb 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -205,7 +205,8 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
struct sockaddr_vm *dst);
void vsock_remove_sock(struct vsock_sock *vsk);
-void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
+void vsock_for_each_connected_socket(struct vsock_transport *transport,
+ void (*fn)(struct sock *sk));
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
bool vsock_find_cid(unsigned int cid);

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 38baeb189d4e..f04abf662ec6 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -334,7 +334,8 @@ void vsock_remove_sock(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_remove_sock);

-void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
+void vsock_for_each_connected_socket(struct vsock_transport *transport,
+ void (*fn)(struct sock *sk))
{
int i;

@@ -343,8 +344,12 @@ void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
struct vsock_sock *vsk;
list_for_each_entry(vsk, &vsock_connected_table[i],
- connected_table)
+ connected_table) {
+ if (vsk->transport != transport)
+ continue;
+
fn(sk_vsock(vsk));
+ }
}

spin_unlock_bh(&vsock_table_lock);
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 61b24eb31d4b..5afc194a58bb 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -358,17 +358,11 @@ static void virtio_vsock_event_fill(struct virtio_vsock *vsock)

static void virtio_vsock_reset_sock(struct sock *sk)
{
- struct vsock_sock *vsk = vsock_sk(sk);
-
/* vmci_transport.c doesn't take sk_lock here either. At least we're
* under vsock_table_lock so the sock cannot disappear while we're
* executing.
*/

- /* Only handle our own sockets */
- if (vsk->transport != &virtio_transport.transport)
- return;
-
sk->sk_state = TCP_CLOSE;
sk->sk_err = ECONNRESET;
sk_error_report(sk);
@@ -391,7 +385,8 @@ static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
switch (le32_to_cpu(event->id)) {
case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
virtio_vsock_update_guest_cid(vsock);
- vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+ vsock_for_each_connected_socket(&virtio_transport.transport,
+ virtio_vsock_reset_sock);
break;
}
}
@@ -669,7 +664,8 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
synchronize_rcu();

/* Reset all connected sockets when the device disappear */
- vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+ vsock_for_each_connected_socket(&virtio_transport.transport,
+ virtio_vsock_reset_sock);

/* Stop all work handlers to make sure no one is accessing the device,
* so we can safely call virtio_reset_device().
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index cd2f01513fae..735d5e14608a 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -803,11 +803,6 @@ static void vmci_transport_handle_detach(struct sock *sk)
struct vsock_sock *vsk;

vsk = vsock_sk(sk);
-
- /* Only handle our own sockets */
- if (vsk->transport != &vmci_transport)
- return;
-
if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) {
sock_set_flag(sk, SOCK_DONE);

@@ -887,7 +882,8 @@ static void vmci_transport_qp_resumed_cb(u32 sub_id,
const struct vmci_event_data *e_data,
void *client_data)
{
- vsock_for_each_connected_socket(vmci_transport_handle_detach);
+ vsock_for_each_connected_socket(&vmci_transport,
+ vmci_transport_handle_detach);
}

static void vmci_transport_recv_pkt_work(struct work_struct *work)
--
2.35.1.723.g4982287a31-goog