[RFC PATCH net-next v6 03/14] af_vsock: support multi-transport datagrams
From: Amery Hung
Date: Wed Jul 10 2024 - 17:27:44 EST
From: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx>
This patch adds support for multi-transport datagrams.
This includes:
- Allow transport to be undecided (i.e., empty) for non-VMCI datagram
use cases during socket creation.
- connect() now assigns the transport for datagram sockets (similar to
connectible sockets)
- Per-packet lookup of transports when using sendto(sockaddr_vm)
- Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
sockaddr_vm
- Rename VSOCK_TRANSPORT_F_DGRAM to VSOCK_TRANSPORT_F_DGRAM_FALLBACK
* Dynamic transport lookup *
Virtio datagrams will follow the h2g/g2h paradigm. Since it is impossible
to know which transport to use during socket creation, the transport is
allowed to remain empty. The transport will be assigned only when
connect() is called. Otherwise, in the sendmsg() path, if sendto() is
used, the cid is used to look up the transport that will be used. In the
recvmsg() path, since the receiving method is generalized and shared by
different transports, there is no longer any need to resolve the transport.
Finally, a couple of checks for empty transport are added in other paths
to prevent null-pointer dereference.
* Compatibility with VMCI *
To preserve backwards compatibility with VMCI, some important changes
are made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
be used for dgrams only if there is not yet a g2h or h2g transport that
has been registered that can transmit the packet. If there is a g2h/h2g
transport for that remote address, then that transport will be used and
not "transport_dgram". This essentially makes "transport_dgram" a
fallback transport for when h2g/g2h has not yet gone online, and so it
is renamed "transport_dgram_fallback". VMCI implements this transport.
The logic around "transport_dgram" needs to be retained to prevent
breaking VMCI:
1) VMCI datagrams existed prior to h2g/g2h and so operate under a
different paradigm. When the vmci transport comes online, it registers
itself with the DGRAM feature, but not H2G/G2H. Only later when the
transport has more information about its environment does it register
H2G or G2H. In the case that a datagram socket is created after
VSOCK_TRANSPORT_F_DGRAM registration but before G2H/H2G registration,
the "transport_dgram" transport is the only registered transport and so
needs to be used.
2) VMCI seems to require a special message be sent by the transport when a
datagram socket calls bind(). Under the h2g/g2h model, the transport
is selected using the remote_addr which is set by connect(). At
bind time there is no remote_addr because often no connect() has been
called yet: the transport is null. Therefore, with a null transport
there doesn't seem to be any good way for a datagram socket to tell the
VMCI transport that it has just had bind() called upon it.
With the new fallback logic, after H2G/G2H comes online the socket layer
will access the VMCI transport via transport_{h2g,g2h}. Prior to H2G/G2H
coming online, the socket layer will access the VMCI transport via
"transport_dgram_fallback".
Only transports with a special datagram fallback use-case such as VMCI
need to register VSOCK_TRANSPORT_F_DGRAM_FALLBACK.
Signed-off-by: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx>
Signed-off-by: Amery Hung <amery.hung@xxxxxxxxxxxxx>
---
drivers/vhost/vsock.c | 1 -
include/linux/virtio_vsock.h | 2 -
include/net/af_vsock.h | 10 +-
net/vmw_vsock/af_vsock.c | 127 +++++++++++++++++++-----
net/vmw_vsock/hyperv_transport.c | 6 --
net/vmw_vsock/virtio_transport.c | 1 -
net/vmw_vsock/virtio_transport_common.c | 7 --
net/vmw_vsock/vmci_transport.c | 2 +-
net/vmw_vsock/vsock_loopback.c | 1 -
9 files changed, 107 insertions(+), 50 deletions(-)
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 97fffa914e66..fa1aefb78016 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -419,7 +419,6 @@ static struct virtio_transport vhost_transport = {
.cancel_pkt = vhost_transport_cancel_pkt,
.dgram_enqueue = virtio_transport_dgram_enqueue,
- .dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,
.stream_enqueue = virtio_transport_stream_enqueue,
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 8b56b8a19ddd..f749a066af46 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -221,8 +221,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
bool virtio_transport_stream_allow(u32 cid, u32 port);
-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
- struct sockaddr_vm *addr);
bool virtio_transport_dgram_allow(u32 cid, u32 port);
int virtio_transport_connect(struct vsock_sock *vsk);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 7aa1f5f2b1a5..44db8f2c507d 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -96,13 +96,13 @@ struct vsock_transport_send_notify_data {
/* Transport features flags */
/* Transport provides host->guest communication */
-#define VSOCK_TRANSPORT_F_H2G 0x00000001
+#define VSOCK_TRANSPORT_F_H2G 0x00000001
/* Transport provides guest->host communication */
-#define VSOCK_TRANSPORT_F_G2H 0x00000002
-/* Transport provides DGRAM communication */
-#define VSOCK_TRANSPORT_F_DGRAM 0x00000004
+#define VSOCK_TRANSPORT_F_G2H 0x00000002
+/* Transport provides fallback for DGRAM communication */
+#define VSOCK_TRANSPORT_F_DGRAM_FALLBACK 0x00000004
/* Transport provides local (loopback) communication */
-#define VSOCK_TRANSPORT_F_LOCAL 0x00000008
+#define VSOCK_TRANSPORT_F_LOCAL 0x00000008
struct vsock_transport {
struct module *module;
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 98d10cd30483..acc15e11700c 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -140,8 +140,8 @@ struct proto vsock_proto = {
static const struct vsock_transport *transport_h2g;
/* Transport used for guest->host communication */
static const struct vsock_transport *transport_g2h;
-/* Transport used for DGRAM communication */
-static const struct vsock_transport *transport_dgram;
+/* Transport used as a fallback for DGRAM communication */
+static const struct vsock_transport *transport_dgram_fallback;
/* Transport used for local communication */
static const struct vsock_transport *transport_local;
static DEFINE_MUTEX(vsock_register_mutex);
@@ -440,19 +440,20 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
return transport;
}
-/* Assign a transport to a socket and call the .init transport callback.
- *
- * Note: for connection oriented socket this must be called when vsk->remote_addr
- * is set (e.g. during the connect() or when a connection request on a listener
- * socket is received).
- * The vsk->remote_addr is used to decide which transport to use:
- * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
- * g2h is not loaded, will use local transport;
- * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
- * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
- * - remote CID > VMADDR_CID_HOST will use host->guest transport;
- */
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
+{
+ const struct vsock_transport *transport;
+
+ transport = vsock_connectible_lookup_transport(cid, flags);
+ if (transport)
+ return transport;
+
+ return transport_dgram_fallback;
+}
+
+static int __vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+ bool create_sock)
{
const struct vsock_transport *new_transport;
struct sock *sk = sk_vsock(vsk);
@@ -476,7 +477,21 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
switch (sk->sk_type) {
case SOCK_DGRAM:
- new_transport = transport_dgram;
+ /* During vsock_create(), the transport cannot be decided yet if
+ * using virtio. While for VMCI, it is transport_dgram_fallback.
+ * Therefore, we try to initialize it to transport_dgram_fallback
+ * so that we don't break VMCI. If VMCI is not present, it is okay
+ * to leave the transport empty since vsk->transport != NULL checks
+ * will be performed in send and receive paths.
+ *
+ * During vsock_dgram_connect(), since remote_cid is available,
+ * the right transport is assigned after lookup.
+ */
+ if (create_sock)
+ new_transport = transport_dgram_fallback;
+ else
+ new_transport = vsock_dgram_lookup_transport(remote_cid,
+ remote_flags);
break;
case SOCK_STREAM:
case SOCK_SEQPACKET:
@@ -501,6 +516,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
vsock_deassign_transport(vsk);
}
+ /* Only allow empty transport during vsock_create() for datagram */
+ if (!new_transport && sk->sk_type == SOCK_DGRAM && create_sock)
+ return 0;
+
/* We increase the module refcnt to prevent the transport unloading
* while there are open sockets assigned to it.
*/
@@ -525,6 +544,23 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
return 0;
}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for connection oriented socket this must be called when vsk->remote_addr
+ * is set (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
+ * g2h is not loaded, will use local transport;
+ * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
+ * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
+ * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+ return __vsock_assign_transport(vsk, psk, false);
+}
EXPORT_SYMBOL_GPL(vsock_assign_transport);
bool vsock_find_cid(unsigned int cid)
@@ -693,6 +729,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
+ if (!vsk->transport || !vsk->transport->dgram_bind)
+ return -EINVAL;
+
return vsk->transport->dgram_bind(vsk, addr);
}
@@ -825,6 +864,9 @@ static void __vsock_release(struct sock *sk, int level)
vsk->transport->release(vsk);
else if (sock_type_connectible(sk->sk_type))
vsock_remove_sock(vsk);
+ else if (sk->sk_type == SOCK_DGRAM &&
+ (!vsk->transport || !vsk->transport->dgram_bind))
+ vsock_remove_sock(vsk);
sock_orphan(sk);
sk->sk_shutdown = SHUTDOWN_MASK;
@@ -1152,6 +1194,9 @@ static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ if (!vsk->transport)
+ return -EINVAL;
+
return vsk->transport->read_skb(vsk, read_actor);
}
@@ -1163,6 +1208,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr;
const struct vsock_transport *transport;
+ bool module_got = false;
if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP;
@@ -1174,19 +1220,40 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
lock_sock(sk);
- transport = vsk->transport;
-
err = vsock_auto_bind(vsk);
if (err)
goto out;
-
/* If the provided message contains an address, use that. Otherwise
* fall back on the socket's remote handle (if it has been connected).
*/
if (msg->msg_name &&
vsock_addr_cast(msg->msg_name, msg->msg_namelen,
&remote_addr) == 0) {
+ transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
+ remote_addr->svm_flags);
+ /* transport_dgram_fallback needs to be initialized to be called */
+ if (transport == transport_dgram_fallback && transport != vsk->transport) {
+ err = -EINVAL;
+ goto out;
+ }
+
+ if (!transport) {
+ err = -EINVAL;
+ goto out;
+ }
+
+ if (!try_module_get(transport->module)) {
+ err = -ENODEV;
+ goto out;
+ }
+
+ /* When looking up a transport dynamically and acquiring a
+ * reference on the module, we need to remember to release the
+ * reference later.
+ */
+ module_got = true;
+
/* Ensure this address is of the right type and is a valid
* destination.
*/
@@ -1201,6 +1268,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
} else if (sock->state == SS_CONNECTED) {
remote_addr = &vsk->remote_addr;
+ transport = vsk->transport;
if (remote_addr->svm_cid == VMADDR_CID_ANY)
remote_addr->svm_cid = transport->get_local_cid();
@@ -1225,6 +1293,8 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
out:
+ if (module_got)
+ module_put(transport->module);
release_sock(sk);
return err;
}
@@ -1257,13 +1327,18 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;
+ memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err)
+ goto out;
+
if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -EINVAL;
goto out;
}
- memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
sock->state = SS_CONNECTED;
/* sock map disallows redirection of non-TCP sockets with sk_state !=
@@ -2406,7 +2481,7 @@ static int vsock_create(struct net *net, struct socket *sock,
vsk = vsock_sk(sk);
if (sock->type == SOCK_DGRAM) {
- ret = vsock_assign_transport(vsk, NULL);
+ ret = __vsock_assign_transport(vsk, NULL, true);
if (ret < 0) {
sock_put(sk);
return ret;
@@ -2548,7 +2623,7 @@ int vsock_core_register(const struct vsock_transport *t, int features)
t_h2g = transport_h2g;
t_g2h = transport_g2h;
- t_dgram = transport_dgram;
+ t_dgram = transport_dgram_fallback;
t_local = transport_local;
if (features & VSOCK_TRANSPORT_F_H2G) {
@@ -2567,7 +2642,7 @@ int vsock_core_register(const struct vsock_transport *t, int features)
t_g2h = t;
}
- if (features & VSOCK_TRANSPORT_F_DGRAM) {
+ if (features & VSOCK_TRANSPORT_F_DGRAM_FALLBACK) {
if (t_dgram) {
err = -EBUSY;
goto err_busy;
@@ -2585,7 +2660,7 @@ int vsock_core_register(const struct vsock_transport *t, int features)
transport_h2g = t_h2g;
transport_g2h = t_g2h;
- transport_dgram = t_dgram;
+ transport_dgram_fallback = t_dgram;
transport_local = t_local;
err_busy:
@@ -2604,8 +2679,8 @@ void vsock_core_unregister(const struct vsock_transport *t)
if (transport_g2h == t)
transport_g2h = NULL;
- if (transport_dgram == t)
- transport_dgram = NULL;
+ if (transport_dgram_fallback == t)
+ transport_dgram_fallback = NULL;
if (transport_local == t)
transport_local = NULL;
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 326dd41ee2d5..64ad87a3879c 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -551,11 +551,6 @@ static void hvs_destruct(struct vsock_sock *vsk)
kfree(hvs);
}
-static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
-{
- return -EOPNOTSUPP;
-}
-
static int hvs_dgram_enqueue(struct vsock_sock *vsk,
struct sockaddr_vm *remote, struct msghdr *msg,
size_t dgram_len)
@@ -826,7 +821,6 @@ static struct vsock_transport hvs_transport = {
.connect = hvs_connect,
.shutdown = hvs_shutdown,
- .dgram_bind = hvs_dgram_bind,
.dgram_enqueue = hvs_dgram_enqueue,
.dgram_allow = hvs_dgram_allow,
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index a8c97e95622a..4891b845fcde 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -507,7 +507,6 @@ static struct virtio_transport virtio_transport = {
.shutdown = virtio_transport_shutdown,
.cancel_pkt = virtio_transport_cancel_pkt,
- .dgram_bind = virtio_transport_dgram_bind,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_allow = virtio_transport_dgram_allow,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 4bf73d20c12a..a1c76836d798 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -1008,13 +1008,6 @@ bool virtio_transport_stream_allow(u32 cid, u32 port)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
- struct sockaddr_vm *addr)
-{
- return -EOPNOTSUPP;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
-
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
return false;
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index b39df3ed8c8d..49aba9c48415 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -2061,7 +2061,7 @@ static int __init vmci_transport_init(void)
/* Register only with dgram feature, other features (H2G, G2H) will be
* registered when the first host or guest becomes active.
*/
- err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM);
+ err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM_FALLBACK);
if (err < 0)
goto err_unsubscribe;
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index 11488887a5cc..4dd4886f29d1 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -65,7 +65,6 @@ static struct virtio_transport loopback_transport = {
.shutdown = virtio_transport_shutdown,
.cancel_pkt = vsock_loopback_cancel_pkt,
- .dgram_bind = virtio_transport_dgram_bind,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_allow = virtio_transport_dgram_allow,
--
2.20.1