[PATCH 3/5] af_vsock: send/receive loops for SOCK_SEQPACKET.
From: Arseny Krasnov
Date: Sun Jan 03 2021 - 15:05:07 EST
From: Arseniy Krasnov <oxffffaa@xxxxxxxxx>
For send, this patch adds:
1) Sending of a record-begin marker carrying the record length.
2) Returning an error if the whole record could not be sent.
For receive, this patch adds another loop. It looks like the
stream loop, but:
1) It doesn't call notify callbacks.
2) It doesn't care about 'SO_SNDLOWAT' and 'SO_RCVLOWAT'
values.
3) It waits until the whole record is received or an error is
found during receiving.
4) It processes and sets the 'MSG_TRUNC' flag.
---
net/vmw_vsock/af_vsock.c | 319 +++++++++++++++++++++++++++++++--------
1 file changed, 256 insertions(+), 63 deletions(-)
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b12d3a322242..7ff00449a9a2 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1683,8 +1683,8 @@ static int vsock_stream_getsockopt(struct socket *sock,
return 0;
}
-static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
- size_t len)
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
struct sock *sk;
struct vsock_sock *vsk;
@@ -1737,6 +1737,12 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
if (err < 0)
goto out;
+ if (sk->sk_type == SOCK_SEQPACKET) {
+ err = transport->seqpacket_seq_send_len(vsk, len);
+ if (err < 0)
+ goto out;
+ }
+
while (total_written < len) {
ssize_t written;
@@ -1796,10 +1802,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
* smaller than the queue size. It is the caller's
* responsibility to check how many bytes we were able to send.
*/
-
- written = transport->stream_enqueue(
- vsk, msg,
- len - total_written);
+ written = transport->stream_enqueue(vsk, msg,
+ len - total_written);
if (written < 0) {
err = -ENOMEM;
goto out_err;
@@ -1815,36 +1819,96 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
}
out_err:
- if (total_written > 0)
- err = total_written;
+ if (total_written > 0) {
+ /* Return number of written bytes only if:
+ * 1) SOCK_STREAM socket.
+ * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+ */
+ if (sk->sk_type == SOCK_STREAM || total_written == len)
+ err = total_written;
+ }
out:
release_sock(sk);
return err;
}
+/* sendmsg() for SOCK_STREAM sockets: delegates to the shared
+ * vsock_connectible_sendmsg() implementation.
+ */
+static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
+{
+ return vsock_connectible_sendmsg(sock, msg, len);
+}
-static int
-vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
- int flags)
+/* sendmsg() for SOCK_SEQPACKET sockets: same shared implementation, which
+ * first sends the record length via seqpacket_seq_send_len() and only
+ * reports success (bytes written) when the whole record was sent.
+ */
+static int vsock_seqpacket_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
- struct sock *sk;
+ return vsock_connectible_sendmsg(sock, msg, len);
+}
+
+/* Sleep on @wait until the receive queue reports data or waiting must stop.
+ *
+ * Must be called with the socket lock held and @wait already queued with
+ * prepare_to_wait(). The lock is dropped around schedule_timeout() and
+ * re-taken before returning. The remaining time returned by
+ * schedule_timeout() is carried across wakeups, so a wakeup that finds the
+ * queue still empty does not restart the full @timeout.
+ *
+ * When @recv_data is non-NULL (stream sockets) the transport's
+ * notify_recv_pre_block() hook is invoked before every sleep.
+ *
+ * Returns:
+ *   0         vsock_stream_has_data() became non-zero; @wait is still
+ *             queued and the caller is expected to re-check the queue.
+ *   -1        socket error or shutdown; the caller inspects sk state.
+ *   -EAGAIN   non-blocking socket or the timeout expired.
+ *   other < 0 signal (sock_intr_errno()) or notify_recv_pre_block() error.
+ * On every non-zero return @wait has been removed with finish_wait().
+ *
+ * NOTE(review): callers still pass their initial @timeout on each call, so
+ * separate waits within one recvmsg() each get the full budget.
+ */
+static int vsock_wait_data(struct sock *sk, struct wait_queue_entry *wait,
+			   long timeout,
+			   struct vsock_transport_recv_notify_data *recv_data,
+			   size_t target)
+{
+	int err;
 	struct vsock_sock *vsk;
 	const struct vsock_transport *transport;
-	int err;
-	size_t target;
-	ssize_t copied;
-	long timeout;
-	struct vsock_transport_recv_notify_data recv_data;
-
-	DEFINE_WAIT(wait);
-	sk = sock->sk;
 	vsk = vsock_sk(sk);
 	transport = vsk->transport;
-	err = 0;
-	lock_sock(sk);
+
+	do {
+		if (sk->sk_err != 0 ||
+		    (sk->sk_shutdown & RCV_SHUTDOWN) ||
+		    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+			err = -1;
+			goto out_finish;
+		}
+
+		/* Don't wait for non-blocking sockets. */
+		if (timeout == 0) {
+			err = -EAGAIN;
+			goto out_finish;
+		}
+
+		/* Only stream callers supply notify state. */
+		if (recv_data) {
+			err = transport->notify_recv_pre_block(vsk, target,
+							       recv_data);
+			if (err < 0)
+				goto out_finish;
+		}
+
+		release_sock(sk);
+		timeout = schedule_timeout(timeout);
+		lock_sock(sk);
+
+		if (signal_pending(current)) {
+			err = sock_intr_errno(timeout);
+			goto out_finish;
+		}
+		if (timeout == 0) {
+			err = -EAGAIN;
+			goto out_finish;
+		}
+
+		/* Re-arm before re-checking the queue so that a wakeup between
+		 * the check and the next schedule_timeout() is not lost.
+		 */
+		prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
+	} while (vsock_stream_has_data(vsk) == 0);
+
+	return 0;
+
+out_finish:
+	finish_wait(sk_sleep(sk), wait);
+	return err;
+}
+
+/* SOCK_SEQPACKET variant: no notify state and no low-water target. */
+static int vsock_wait_data_seqpacket(struct sock *sk, struct wait_queue_entry *wait,
+				     long timeout)
+{
+	return vsock_wait_data(sk, wait, timeout, NULL, 0);
+}
+
+static int vsock_pre_recv_check(struct socket *sock,
+ int flags, size_t len, int *err)
+{
+ struct sock *sk;
+ struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
+
+ sk = sock->sk;
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
+
if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
@@ -1852,16 +1916,16 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
* SOCK_DONE flag.
*/
if (sock_flag(sk, SOCK_DONE))
- err = 0;
+ *err = 0;
else
- err = -ENOTCONN;
+ *err = -ENOTCONN;
- goto out;
+ return false;
}
if (flags & MSG_OOB) {
- err = -EOPNOTSUPP;
- goto out;
+ *err = -EOPNOTSUPP;
+ return false;
}
/* We don't check peer_shutdown flag here since peer may actually shut
@@ -1869,17 +1933,143 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
* receive.
*/
if (sk->sk_shutdown & RCV_SHUTDOWN) {
- err = 0;
- goto out;
+ *err = 0;
+ return false;
}
/* It is valid on Linux to pass in a zero-length receive buffer. This
* is not an error. We may as well bail out now.
*/
if (!len) {
+ *err = 0;
+ return false;
+ }
+
+ return true;
+}
+
+/* Receive one SOCK_SEQPACKET record into @msg.
+ *
+ * Called with the socket lock held (taken by vsock_connectible_recvmsg()).
+ * Loops until the record length reported by seqpacket_seq_get_len() has
+ * been dequeued, a dequeue error occurs, or waiting is aborted (timeout,
+ * signal, socket error or shutdown).
+ *
+ * Returns the number of bytes copied into @msg (or the full record length
+ * when MSG_TRUNC is set in @flags), 0 on orderly shutdown, or a negative
+ * errno (-EAGAIN for non-blocking/timeout, a signal errno, -ENOMEM, or
+ * -sk_err). MSG_TRUNC is set in msg->msg_flags when the record is longer
+ * than @len.
+ */
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	int err = 0;
+	size_t record_len = 0;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+	long timeout;
+	ssize_t dequeued_total = 0;
+	/* Snapshot of the destination iterator so a record can be re-read
+	 * from the start. Copying the whole iov_iter works for any iterator
+	 * type, unlike re-initialising it from msg_iter.iov.
+	 */
+	struct iov_iter orig_iter = msg->msg_iter;
+	DEFINE_WAIT(wait);
+
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
+	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	msg->msg_flags &= ~MSG_EOR;
+
+	while (1) {
+		s64 ready;
+
+		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+		ready = vsock_stream_has_data(vsk);
+
+		if (ready == 0) {
+			err = vsock_wait_data_seqpacket(sk, &wait, timeout);
+			if (err) {
+				/* -1 only means "stop waiting" (socket error or
+				 * shutdown); the status is derived from the socket
+				 * state after the loop. -EAGAIN and signal errors
+				 * are propagated to the caller.
+				 */
+				if (err == -1)
+					err = 0;
+
+				/* In case of any loop break (timeout, signal
+				 * interrupt or shutdown), we report to the user
+				 * that nothing was copied.
+				 * NOTE(review): bytes of a partially received
+				 * record that were already dequeued are dropped.
+				 */
+				dequeued_total = 0;
+				break;
+			}
+		} else {
+			ssize_t dequeued;
+
+			finish_wait(sk_sleep(sk), &wait);
+
+			if (ready < 0) {
+				err = -ENOMEM;
+				goto out;
+			}
+
+			if (dequeued_total == 0) {
+				record_len =
+					transport->seqpacket_seq_get_len(vsk);
+
+				/* NOTE(review): this retries without sleeping;
+				 * confirm the transport never reports data without
+				 * a record header, otherwise this spins.
+				 */
+				if (record_len == 0)
+					continue;
+			}
+
+			/* 'msg_iter.count' is number of unused bytes in iov.
+			 * On every copy to iov iterator it is decremented at
+			 * size of data.
+			 */
+			dequeued = transport->stream_dequeue(vsk, msg,
+							     msg->msg_iter.count,
+							     flags);
+
+			if (dequeued < 0) {
+				dequeued_total = 0;
+
+				if (dequeued == -EAGAIN) {
+					/* Restart the record: rewind the
+					 * destination to its state on entry.
+					 */
+					msg->msg_iter = orig_iter;
+					msg->msg_flags &= ~MSG_EOR;
+					continue;
+				}
+
+				err = -ENOMEM;
+				break;
+			}
+
+			dequeued_total += dequeued;
+
+			if (dequeued_total >= record_len)
+				break;
+		}
+	}
+
+	if (sk->sk_err)
+		err = -sk->sk_err;
+	else if (sk->sk_shutdown & RCV_SHUTDOWN)
 		err = 0;
-		goto out;
+
+	if (dequeued_total > 0) {
+		/* User sets MSG_TRUNC, so return real length of
+		 * packet.
+		 */
+		if (flags & MSG_TRUNC)
+			err = record_len;
+		else
+			err = len - msg->msg_iter.count;
+
+		/* Always set MSG_TRUNC if real length of packet is
+		 * bigger than user buffer.
+		 */
+		if (record_len > len)
+			msg->msg_flags |= MSG_TRUNC;
 	}
+out:
+	return err;
+}
+
+static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+ size_t len, int flags)
+{
+ int err;
+ const struct vsock_transport *transport;
+ struct vsock_sock *vsk;
+ size_t target;
+ struct vsock_transport_recv_notify_data recv_data;
+ long timeout;
+ ssize_t copied;
+
+ DEFINE_WAIT(wait);
+
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
/* We must not copy less than target bytes into the user's buffer
* before returning successfully, so we wait for the consume queue to
@@ -1907,38 +2097,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
ready = vsock_stream_has_data(vsk);
if (ready == 0) {
- if (sk->sk_err != 0 ||
- (sk->sk_shutdown & RCV_SHUTDOWN) ||
- (vsk->peer_shutdown & SEND_SHUTDOWN)) {
- finish_wait(sk_sleep(sk), &wait);
- break;
- }
- /* Don't wait for non-blocking sockets. */
- if (timeout == 0) {
- err = -EAGAIN;
- finish_wait(sk_sleep(sk), &wait);
- break;
- }
-
- err = transport->notify_recv_pre_block(
- vsk, target, &recv_data);
- if (err < 0) {
- finish_wait(sk_sleep(sk), &wait);
+ if (vsock_wait_data(sk, &wait, timeout, &recv_data, target))
break;
- }
- release_sock(sk);
- timeout = schedule_timeout(timeout);
- lock_sock(sk);
-
- if (signal_pending(current)) {
- err = sock_intr_errno(timeout);
- finish_wait(sk_sleep(sk), &wait);
- break;
- } else if (timeout == 0) {
- err = -EAGAIN;
- finish_wait(sk_sleep(sk), &wait);
- break;
- }
} else {
ssize_t read;
@@ -1959,9 +2119,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (err < 0)
break;
- read = transport->stream_dequeue(
- vsk, msg,
- len - copied, flags);
+ read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+
if (read < 0) {
err = -ENOMEM;
break;
@@ -1990,11 +2149,45 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (copied > 0)
err = copied;
+out:
+ return err;
+}
+
+/* Shared recvmsg() body for connection-oriented vsock sockets.
+ *
+ * Takes the socket lock, runs the common pre-checks (no transport or not
+ * established, MSG_OOB, receive shutdown, zero-length buffer) and, if they
+ * pass, dispatches on sk->sk_type: SOCK_STREAM goes to
+ * __vsock_stream_recvmsg(), every other type to __vsock_seqpacket_recvmsg().
+ */
+static int vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	struct sock *sk = sock->sk;
+	int err;
+
+	lock_sock(sk);
+
+	/* On failure the pre-check has already stored the return value. */
+	if (!vsock_pre_recv_check(sock, flags, len, &err))
+		goto out;
+
+	err = (sk->sk_type == SOCK_STREAM) ?
+	      __vsock_stream_recvmsg(sk, msg, len, flags) :
+	      __vsock_seqpacket_recvmsg(sk, msg, len, flags);
+
 out:
 	release_sock(sk);
 	return err;
 }
+/* Per-socket-type recvmsg() wrappers. Both forward unchanged to
+ * vsock_connectible_recvmsg(), which picks the stream or seqpacket path
+ * from sk->sk_type.
+ */
+static int vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg,
+				size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
+static int vsock_seqpacket_recvmsg(struct socket *sock, struct msghdr *msg,
+				   size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
static const struct proto_ops vsock_stream_ops = {
.family = PF_VSOCK,
.owner = THIS_MODULE,
--
2.25.1