Re: [PATCH 3/5] af_vsock: send/receive loops for SOCK_SEQPACKET.

From: stsp
Date: Sun Jan 03 2021 - 15:51:58 EST


Hi Arseny!

03.01.2021 23:03, Arseny Krasnov пишет:
From: Arseniy Krasnov <oxffffaa@xxxxxxxxx>

For send, this patch adds:
1) Send of record begin marker with record length.
2) Return error if send of whole record is failed.

For receive, this patch adds another loop, it looks like
stream loop, but:
1) It doesn't call notify callbacks.
2) It doesn't care about 'SO_SNDLOWAT' and 'SO_RCVLOWAT'
values.
3) It waits until whole record is received or error is
found during receiving.
4) It processes and sets the 'MSG_TRUNC' flag.
---
net/vmw_vsock/af_vsock.c | 319 +++++++++++++++++++++++++++++++--------
1 file changed, 256 insertions(+), 63 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b12d3a322242..7ff00449a9a2 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1683,8 +1683,8 @@ static int vsock_stream_getsockopt(struct socket *sock,
return 0;
}
-static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
- size_t len)
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
struct sock *sk;
struct vsock_sock *vsk;
@@ -1737,6 +1737,12 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
if (err < 0)
goto out;
+ if (sk->sk_type == SOCK_SEQPACKET) {
+ err = transport->seqpacket_seq_send_len(vsk, len);
+ if (err < 0)
+ goto out;
+ }
+
while (total_written < len) {
ssize_t written;
@@ -1796,10 +1802,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
* smaller than the queue size. It is the caller's
* responsibility to check how many bytes we were able to send.
*/
-
- written = transport->stream_enqueue(
- vsk, msg,
- len - total_written);
+ written = transport->stream_enqueue(vsk, msg,
+ len - total_written);

White-space change?


if (written < 0) {
err = -ENOMEM;
goto out_err;
@@ -1815,36 +1819,96 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
}
out_err:
- if (total_written > 0)
- err = total_written;
+ if (total_written > 0) {
+ /* Return number of written bytes only if:
+ * 1) SOCK_STREAM socket.
+ * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+ */
+ if (sk->sk_type == SOCK_STREAM || total_written == len)
+ err = total_written;
+ }
out:
release_sock(sk);
return err;
}
+static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
+{
+ return vsock_connectible_sendmsg(sock, msg, len);
+}
-static int
-vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
- int flags)
+static int vsock_seqpacket_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
- struct sock *sk;
+ return vsock_connectible_sendmsg(sock, msg, len);
+}
+
+static int vsock_wait_data(struct sock *sk, struct wait_queue_entry *wait,
+ long timeout,
+ struct vsock_transport_recv_notify_data *recv_data,
+ size_t target)
+{

Your patch looks quite large because
of this, so would it make sense to separate
out the refactoring part (vsock_wait_data()
and friends, which you seem to have copied out of
the recvmsg() code) as a separate patch?
Currently it's a bit difficult to see what was
added and what was "refactored".


+ int err = 0;
struct vsock_sock *vsk;
const struct vsock_transport *transport;
- int err;
- size_t target;
- ssize_t copied;
- long timeout;
- struct vsock_transport_recv_notify_data recv_data;
-
- DEFINE_WAIT(wait);
- sk = sock->sk;
vsk = vsock_sk(sk);
transport = vsk->transport;
- err = 0;
+ if (sk->sk_err != 0 ||
+ (sk->sk_shutdown & RCV_SHUTDOWN) ||
+ (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+ finish_wait(sk_sleep(sk), wait);
+ return -1;
+ }
+ /* Don't wait for non-blocking sockets. */
+ if (timeout == 0) {
+ err = -EAGAIN;
+ finish_wait(sk_sleep(sk), wait);
+ return err;
+ }
+
+ if (sk->sk_type == SOCK_STREAM) {
+ err = transport->notify_recv_pre_block(vsk, target, recv_data);
+ if (err < 0) {
+ finish_wait(sk_sleep(sk), wait);
+ return err;
+ }
+ }
+
+ release_sock(sk);
+ timeout = schedule_timeout(timeout);
lock_sock(sk);
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeout);
+ finish_wait(sk_sleep(sk), wait);
+ } else if (timeout == 0) {
+ err = -EAGAIN;
+ finish_wait(sk_sleep(sk), wait);
+ }
+
+ return err;
+}
+
+static int vsock_wait_data_seqpacket(struct sock *sk, struct wait_queue_entry *wait,
+ long timeout)
+{
+ return vsock_wait_data(sk, wait, timeout, NULL, 0);

Would it make sense to structure that
differently? If vsock_wait_data() does
"more things" than vsock_wait_data_seqpacket(),
then would it be possible to make
vsock_wait_data() call vsock_wait_data_seqpacket()
(or some common part of both), rather
than nulling out the unused arguments?


+}
+
+static int vsock_pre_recv_check(struct socket *sock,
+ int flags, size_t len, int *err)
+{
+ struct sock *sk;
+ struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
+
+ sk = sock->sk;
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
+
if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
@@ -1852,16 +1916,16 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
* SOCK_DONE flag.
*/
if (sock_flag(sk, SOCK_DONE))
- err = 0;
+ *err = 0;
else
- err = -ENOTCONN;
+ *err = -ENOTCONN;
- goto out;
+ return false;

Hmm, are you sure you need to convert
"err" to a pointer, just to return true/false
as the return value?
How about still returning "err" itself?


}
if (flags & MSG_OOB) {
- err = -EOPNOTSUPP;
- goto out;
+ *err = -EOPNOTSUPP;
+ return false;
}
/* We don't check peer_shutdown flag here since peer may actually shut
@@ -1869,17 +1933,143 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
* receive.
*/
if (sk->sk_shutdown & RCV_SHUTDOWN) {
- err = 0;
- goto out;
+ *err = 0;
+ return false;
}
/* It is valid on Linux to pass in a zero-length receive buffer. This
* is not an error. We may as well bail out now.
*/
if (!len) {
+ *err = 0;
+ return false;
+ }
+
+ return true;
+}
+
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+ size_t len, int flags)
+{
+ int err = 0;
+ size_t record_len;
+ struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
+ long timeout;
+ ssize_t dequeued_total = 0;
+ unsigned long orig_nr_segs;
+ const struct iovec *orig_iov;
+ DEFINE_WAIT(wait);
+
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
+
+ timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+ msg->msg_flags &= ~MSG_EOR;
+ orig_nr_segs = msg->msg_iter.nr_segs;
+ orig_iov = msg->msg_iter.iov;
+
+ while (1) {
+ s64 ready;
+
+ prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+ ready = vsock_stream_has_data(vsk);
+
+ if (ready == 0) {
+ if (vsock_wait_data_seqpacket(sk, &wait, timeout)) {
+ /* In case of any loop break(timeout, signal
+ * interrupt or shutdown), we report user that
+ * nothing was copied.
+ */
+ dequeued_total = 0;
+ break;
+ }
+ } else {
+ ssize_t dequeued;
+
+ finish_wait(sk_sleep(sk), &wait);
+
+ if (ready < 0) {
+ err = -ENOMEM;
+ goto out;
+ }
+
+ if (dequeued_total == 0) {
+ record_len =
+ transport->seqpacket_seq_get_len(vsk);
+
+ if (record_len == 0)
+ continue;
+ }
+
+ /* 'msg_iter.count' is number of unused bytes in iov.
+ * On every copy to iov iterator it is decremented at
+ * size of data.
+ */
+ dequeued = transport->stream_dequeue(vsk, msg,
+ msg->msg_iter.count, flags);
+
+ if (dequeued < 0) {
+ dequeued_total = 0;
+
+ if (dequeued == -EAGAIN) {
+ iov_iter_init(&msg->msg_iter, READ,
+ orig_iov, orig_nr_segs,
+ len);
+ msg->msg_flags &= ~MSG_EOR;
+ continue;
+ }
+
+ err = -ENOMEM;
+ break;
+ }
+
+ dequeued_total += dequeued;
+
+ if (dequeued_total >= record_len)
+ break;
+ }
+ }
+
+ if (sk->sk_err)
+ err = -sk->sk_err;
+ else if (sk->sk_shutdown & RCV_SHUTDOWN)
err = 0;
- goto out;
+
+ if (dequeued_total > 0) {
+ /* User sets MSG_TRUNC, so return real length of
+ * packet.
+ */
+ if (flags & MSG_TRUNC)
+ err = record_len;
+ else
+ err = len - msg->msg_iter.count;

It's not very clear (only to me, perhaps) how
dequeued_total and len correlate. Are they
equal here? Would you need to check that
dequeued_total >= record_len?
I mean, it's just a bit strange that you check
dequeued_total > 0 and no longer use that variable
inside the block.