[RFC PATCH 07/10] tcp: implement sendmsg() TX path for devmem TCP

From: Mina Almasry
Date: Mon Jul 10 2023 - 18:35:19 EST


For device memory TCP, we let the user provide the kernel with a cmsg
containing 2 items:

1. the dmabuf pages fd that the user would like to send data from.
2. the offset into this dmabuf from which the user would like to start
sending.

In tcp_sendmsg_locked(), if this cmsg is provided, we send the data
using the dmabuf NET_TX pages bio_vec.
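
As a rough userspace illustration (not part of this patch), the sketch
below builds such a cmsg and passes it to sendmsg(). Only the
SOL_SOCKET/SCM_DEVMEM_OFFSET layout (two u32s: fd, then offset) comes
from the __sock_cmsg_send() change below; the helper name, the dummy
iovec and the use of MSG_ZEROCOPY are assumptions.

#include <stdint.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

/* Hypothetical helper: send `len` bytes that live at `offset` inside
 * the NET_TX dmabuf pages referred to by `devmem_fd`.
 */
static ssize_t send_devmem(int sock, int devmem_fd, uint32_t offset,
			   char *dummy, size_t len)
{
	uint32_t cmsg_data[2] = { (uint32_t)devmem_fd, offset };
	/* The payload presumably comes from the dmabuf; the iovec is
	 * assumed to only convey the length to transmit.
	 */
	struct iovec iov = { .iov_base = dummy, .iov_len = len };
	union {
		char buf[CMSG_SPACE(sizeof(cmsg_data))];
		struct cmsghdr align;
	} ctrl = { 0 };
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = ctrl.buf,
		.msg_controllen = sizeof(ctrl.buf),
	};
	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);

	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_DEVMEM_OFFSET;	/* from this series' uapi */
	cmsg->cmsg_len = CMSG_LEN(sizeof(cmsg_data));
	memcpy(CMSG_DATA(cmsg), cmsg_data, sizeof(cmsg_data));

	return sendmsg(sock, &msg, MSG_ZEROCOPY);
}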

Also provide drivers with a new skb_devmem_frag_dma_map() helper. This
helper is similar to skb_frag_dma_map(), but it first checks whether the
frag being mapped is backed by dmabuf NET_TX pages and, if so, returns
the DMA address already provided by the dmabuf instead of mapping the
page again.
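
For illustration only, a driver TX path might use the helper roughly as
in the sketch below. The function name and the descriptor-writing step
are hypothetical; only the skb_devmem_frag_dma_map() call reflects this
patch.

#include <linux/dma-mapping.h>
#include <linux/skbuff.h>

static int hypothetical_map_tx_frags(struct device *dev, struct sk_buff *skb)
{
	int i;

	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		const skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
		dma_addr_t dma;

		/* For devmem frags this resolves to the address the
		 * dmabuf already provides; otherwise it falls back to
		 * skb_frag_dma_map(), i.e. a regular dma_map_page().
		 */
		dma = skb_devmem_frag_dma_map(dev, skb, frag, 0,
					      skb_frag_size(frag),
					      DMA_TO_DEVICE);
		if (dma_mapping_error(dev, dma))
			return -ENOMEM;

		/* ... write dma + skb_frag_size(frag) into the device's
		 * TX descriptor here ...
		 */
	}

	return 0;
}

Note the asymmetry a real driver would presumably have to handle on
completion: devmem frag addresses were not created by the driver's own
dma_map call, so its unmap path would need to skip them.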

Signed-off-by: Mina Almasry <almasrymina@xxxxxxxxxx>
---
include/linux/skbuff.h | 19 +++++++++--
include/net/sock.h | 2 ++
net/core/skbuff.c | 8 ++---
net/core/sock.c | 6 ++++
net/ipv4/tcp.c | 75 +++++++++++++++++++++++++++++++++++++++++-
5 files changed, 103 insertions(+), 7 deletions(-)

diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index f5e03aa84160..ad4e7bfcab07 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -1660,8 +1660,8 @@ static inline int skb_zerocopy_iter_dgram(struct sk_buff *skb,
}

int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb,
-			     struct msghdr *msg, int len,
-			     struct ubuf_info *uarg);
+			     struct msghdr *msg, struct iov_iter *iov_iter,
+			     int len, struct ubuf_info *uarg);

/* Internal */
#define skb_shinfo(SKB) ((struct skb_shared_info *)(skb_end_pointer(SKB)))
@@ -3557,6 +3557,21 @@ static inline dma_addr_t skb_frag_dma_map(struct device *dev,
skb_frag_off(frag) + offset, size, dir);
}

+/* Similar to skb_frag_dma_map, but handles devmem skbs correctly. */
+static inline dma_addr_t skb_devmem_frag_dma_map(struct device *dev,
+						  const struct sk_buff *skb,
+						  const skb_frag_t *frag,
+						  size_t offset, size_t size,
+						  enum dma_data_direction dir)
+{
+	if (unlikely(skb->devmem && is_dma_buf_page(skb_frag_page(frag)))) {
+		dma_addr_t dma_addr =
+			dma_buf_page_to_dma_addr(skb_frag_page(frag));
+		return dma_addr + skb_frag_off(frag) + offset;
+	}
+	return skb_frag_dma_map(dev, frag, offset, size, dir);
+}
+
static inline struct sk_buff *pskb_copy(struct sk_buff *skb,
gfp_t gfp_mask)
{
diff --git a/include/net/sock.h b/include/net/sock.h
index c615666ff19a..733865f89635 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1890,6 +1890,8 @@ struct sockcm_cookie {
u64 transmit_time;
u32 mark;
u32 tsflags;
+	u32 devmem_fd;
+	u32 devmem_offset;
};

static inline void sockcm_init(struct sockcm_cookie *sockc,
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index 9b83da794641..b1e28e7ad6a8 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -1685,8 +1685,8 @@ void msg_zerocopy_put_abort(struct ubuf_info *uarg, bool have_uref)
EXPORT_SYMBOL_GPL(msg_zerocopy_put_abort);

int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb,
-			     struct msghdr *msg, int len,
-			     struct ubuf_info *uarg)
+			     struct msghdr *msg, struct iov_iter *iov_iter,
+			     int len, struct ubuf_info *uarg)
{
struct ubuf_info *orig_uarg = skb_zcopy(skb);
int err, orig_len = skb->len;
@@ -1697,12 +1697,12 @@ int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb,
if (orig_uarg && uarg != orig_uarg)
return -EEXIST;

-	err = __zerocopy_sg_from_iter(msg, sk, skb, &msg->msg_iter, len);
+	err = __zerocopy_sg_from_iter(msg, sk, skb, iov_iter, len);
if (err == -EFAULT || (err == -EMSGSIZE && skb->len == orig_len)) {
struct sock *save_sk = skb->sk;

/* Streams do not free skb on error. Reset to prev state. */
-		iov_iter_revert(&msg->msg_iter, skb->len - orig_len);
+		iov_iter_revert(iov_iter, skb->len - orig_len);
skb->sk = sk;
___pskb_trim(skb, orig_len);
skb->sk = save_sk;
diff --git a/net/core/sock.c b/net/core/sock.c
index f9b9d9ec7322..854624bee5d0 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2813,6 +2813,12 @@ int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg,
return -EINVAL;
sockc->transmit_time = get_unaligned((u64 *)CMSG_DATA(cmsg));
break;
+	case SCM_DEVMEM_OFFSET:
+		if (cmsg->cmsg_len != CMSG_LEN(2 * sizeof(u32)))
+			return -EINVAL;
+		sockc->devmem_fd = ((u32 *)CMSG_DATA(cmsg))[0];
+		sockc->devmem_offset = ((u32 *)CMSG_DATA(cmsg))[1];
+		break;
/* SCM_RIGHTS and SCM_CREDENTIALS are semantically in SOL_UNIX. */
case SCM_RIGHTS:
case SCM_CREDENTIALS:
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index a894b8a9dbb0..85d6cdc832ef 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -280,6 +280,7 @@
#include <asm/ioctls.h>
#include <net/busy_poll.h>
#include <linux/dma-buf.h>
+#include <uapi/linux/dma-buf.h>

/* Track pending CMSGs. */
enum {
@@ -1216,6 +1217,54 @@ int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, int *copied,
return err;
}

+static int tcp_prepare_devmem_data(struct msghdr *msg, int devmem_fd,
+				    unsigned int devmem_offset,
+				    struct file **devmem_file,
+				    struct iov_iter *devmem_tx_iter, size_t size)
+{
+	struct dma_buf_pages *priv;
+	int err = 0;
+
+	*devmem_file = fget_raw(devmem_fd);
+	if (!*devmem_file) {
+		err = -EINVAL;
+		goto err;
+	}
+
+	if (!is_dma_buf_pages_file(*devmem_file)) {
+		err = -EBADF;
+		goto err_fput;
+	}
+
+	priv = (*devmem_file)->private_data;
+	if (!priv) {
+		WARN_ONCE(!priv, "dma_buf_pages_file has no private_data");
+		err = -EINVAL;
+		goto err_fput;
+	}
+
+	if (!(priv->type & DMA_BUF_PAGES_NET_TX)) {
+		err = -EINVAL;
+		goto err_fput;
+	}
+
+	if (devmem_offset + size > priv->dmabuf->size) {
+		err = -ENOSPC;
+		goto err_fput;
+	}
+
+	*devmem_tx_iter = priv->net_tx.iter;
+	iov_iter_advance(devmem_tx_iter, devmem_offset);
+
+	return 0;
+
+err_fput:
+	fput(*devmem_file);
+	*devmem_file = NULL;
+err:
+	return err;
+}
+
int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
{
struct tcp_sock *tp = tcp_sk(sk);
@@ -1227,6 +1274,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
int process_backlog = 0;
bool zc = false;
long timeo;
+	struct file *devmem_file = NULL;
+	struct iov_iter devmem_tx_iter;

flags = msg->msg_flags;

@@ -1295,6 +1344,14 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
}
}

+	if (sockc.devmem_fd) {
+		err = tcp_prepare_devmem_data(msg, sockc.devmem_fd,
+					      sockc.devmem_offset, &devmem_file,
+					      &devmem_tx_iter, size);
+		if (err)
+			goto out_err;
+	}
+
/* This should be in poll */
sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

@@ -1408,7 +1465,17 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
goto wait_for_space;
}

-			err = skb_zerocopy_iter_stream(sk, skb, msg, copy, uarg);
+			if (devmem_file) {
+				err = skb_zerocopy_iter_stream(sk, skb, msg,
+							       &devmem_tx_iter,
+							       copy, uarg);
+				if (err > 0)
+					iov_iter_advance(&msg->msg_iter, err);
+			} else {
+				err = skb_zerocopy_iter_stream(sk, skb, msg,
+							       &msg->msg_iter,
+							       copy, uarg);
+			}
if (err == -EMSGSIZE || err == -EEXIST) {
tcp_mark_push(tp, skb);
goto new_segment;
@@ -1462,6 +1529,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
}
out_nopush:
net_zcopy_put(uarg);
+	if (devmem_file)
+		fput(devmem_file);
return copied + copied_syn;

do_error:
@@ -1470,6 +1539,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
if (copied + copied_syn)
goto out;
out_err:
+	if (devmem_file)
+		fput(devmem_file);
net_zcopy_put_abort(uarg, true);
err = sk_stream_error(sk, flags, err);
/* make sure we wake any epoll edge trigger waiter */
--
2.41.0.390.g38632f3daf-goog