In tcp_recvmsg_locked(), detect if the skb being received by the user
is a devmem skb. In this case, if the user provided the MSG_SOCK_DEVMEM
flag, pass it to tcp_recvmsg_dmabuf() for custom handling.

tcp_recvmsg_dmabuf() copies any data in the skb header to the linear
buffer and returns a cmsg to the user indicating the number of bytes
returned in the linear buffer.

tcp_recvmsg_dmabuf() then loops over the inaccessible devmem skb frags
and returns to the user a dmabuf_cmsg indicating the location of the
data in the dmabuf device memory. The dmabuf_cmsg carries this
information (see the userspace sketch below):

1. the offset into the dmabuf where the payload starts ('frag_offset').
2. the size of the frag ('frag_size').
3. an opaque token 'frag_token' to return to the kernel when the buffer
   is to be released.
4. the id of the dmabuf the frag belongs to ('dmabuf_id').
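
For illustration, the receive side in userspace looks roughly like the
sketch below. Error handling is elided; MSG_SOCK_DEVMEM, struct
dmabuf_cmsg and the SCM_DEVMEM_* constants are the uapi additions from
elsewhere in this series, and 'fd' is assumed to be a connected TCP
socket whose data lands on a dmabuf-bound RX queue:

	char iobuf[8192];	/* linear buffer for any header bytes */
	char ctrl[2048];	/* space for the dmabuf cmsgs */
	struct iovec iov = { .iov_base = iobuf, .iov_len = sizeof(iobuf) };
	struct msghdr msg = { 0 };
	struct cmsghdr *cm;

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_control = ctrl;
	msg.msg_controllen = sizeof(ctrl);

	recvmsg(fd, &msg, MSG_SOCK_DEVMEM);

	for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
		struct dmabuf_cmsg *dc = (struct dmabuf_cmsg *)CMSG_DATA(cm);

		if (cm->cmsg_level != SOL_SOCKET)
			continue;

		if (cm->cmsg_type == SCM_DEVMEM_LINEAR) {
			/* dc->frag_size bytes were copied into iobuf */
		} else if (cm->cmsg_type == SCM_DEVMEM_DMABUF) {
			/* dc->frag_size bytes of payload sit in the dmabuf
			 * at dc->frag_offset; keep dc->frag_token until the
			 * data has been consumed.
			 */
		}
	}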

The frags awaiting freeing are stored in the newly added
sk->sk_user_frags xarray, and each frag handed to userspace holds an
extra pp_ref_count reference on its net_iov. This reference is dropped
once userspace indicates that it is done reading the frag. Any frags
still outstanding are released when the socket is destroyed.
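
Releasing a frag back to the kernel, again as a rough sketch (struct
dmabuf_token and the SO_DEVMEM_DONTNEED setsockopt come from elsewhere
in this series; 'dc' is the dmabuf_cmsg from the sketch above):

	struct dmabuf_token tok;

	tok.token_start = dc->frag_token;
	tok.token_count = 1;

	/* Drops the kernel's reference so the frag can be recycled. */
	setsockopt(fd, SOL_SOCKET, SO_DEVMEM_DONTNEED, &tok, sizeof(tok));

A run of consecutive tokens can be returned in one call by raising
token_count.
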
Signed-off-by: Willem de Bruijn <willemb@xxxxxxxxxx>...
Signed-off-by: Kaiyuan Zhang <kaiyuanz@xxxxxxxxxx>
Signed-off-by: Mina Almasry <almasrymina@xxxxxxxxxx>
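+/* Refill the pool of pre-allocated tokens when it runs dry: commit any
+ * tokens handed out so far, then reserve up to @max_frags placeholder
+ * (XA_ZERO_ENTRY) slots in sk->sk_user_frags. The reserved slots are
+ * populated with the corresponding netmems by tcp_xa_pool_commit().
+ */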
+static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p,
+			      unsigned int max_frags)
+{
+	int err, k;
+
+	if (p->idx < p->max)
+		return 0;
+
+	xa_lock_bh(&sk->sk_user_frags);
+
+	tcp_xa_pool_commit_locked(sk, p);
+
+	for (k = 0; k < max_frags; k++) {
+		err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k],
+				 XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL);
+		if (err)
+			break;
+	}
+
+	xa_unlock_bh(&sk->sk_user_frags);
+
+	p->max = k;
+	p->idx = 0;
+	return k ? 0 : err;
+}
+
+/* On error, returns -errno. On success, returns the number of bytes sent to
+ * the user. May not consume all of @remaining_len.
+ */
+static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb,
+			      unsigned int offset, struct msghdr *msg,
+			      int remaining_len)
+{
+	struct dmabuf_cmsg dmabuf_cmsg = { 0 };
+	struct tcp_xa_pool tcp_xa_pool;
+	unsigned int start;
+	int i, copy, n;
+	int sent = 0;
+	int err = 0;
+
+	tcp_xa_pool.max = 0;
+	tcp_xa_pool.idx = 0;
+	do {
+		start = skb_headlen(skb);
+
+		if (skb_frags_readable(skb)) {
+			err = -ENODEV;
+			goto out;
+		}
+
+		/* Copy header. */
+		copy = start - offset;
+		if (copy > 0) {
+			copy = min(copy, remaining_len);
+
+			n = copy_to_iter(skb->data + offset, copy,
+					 &msg->msg_iter);
+			if (n != copy) {
+				err = -EFAULT;
+				goto out;
+			}
+
+			offset += copy;
+			remaining_len -= copy;
+
+			/* First a dmabuf_cmsg for # bytes copied to user
+			 * buffer.
+			 */
+			memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg));
+			dmabuf_cmsg.frag_size = copy;
+			err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR,
+				       sizeof(dmabuf_cmsg), &dmabuf_cmsg);
+			if (err || msg->msg_flags & MSG_CTRUNC) {
+				msg->msg_flags &= ~MSG_CTRUNC;
+				if (!err)
+					err = -ETOOSMALL;
+				goto out;
+			}
+
+			sent += copy;
+
+			if (remaining_len == 0)
+				goto out;
+		}
+
+		/* after that, send information of dmabuf pages through a
+		 * sequence of cmsg
+		 */
+		for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+			skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
+			struct net_iov *niov;
+			u64 frag_offset;
+			int end;
+
+			/* !skb_frags_readable() should indicate that ALL the
+			 * frags in this skb are dmabuf net_iovs. We're checking
+			 * for that flag above, but also check individual frags
+			 * here. If the tcp stack is not setting
+			 * skb_frags_readable() correctly, we still don't want
+			 * to crash here.
+			 */
+			if (!skb_frag_net_iov(frag)) {
+				net_err_ratelimited("Found non-dmabuf skb with net_iov");
+				err = -ENODEV;
+				goto out;
+			}
+
+			niov = skb_frag_net_iov(frag);
+			end = start + skb_frag_size(frag);
+			copy = end - offset;
+
+			if (copy > 0) {
+				copy = min(copy, remaining_len);
+
+				frag_offset = net_iov_virtual_addr(niov) +
+					      skb_frag_off(frag) + offset -
+					      start;
+				dmabuf_cmsg.frag_offset = frag_offset;
+				dmabuf_cmsg.frag_size = copy;
+				err = tcp_xa_pool_refill(sk, &tcp_xa_pool,
+							 skb_shinfo(skb)->nr_frags - i);
+				if (err)
+					goto out;
+
+				/* Will perform the exchange later */
+				dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx];
+				dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov);
+
+				offset += copy;
+				remaining_len -= copy;
+
+				err = put_cmsg(msg, SOL_SOCKET,
+					       SO_DEVMEM_DMABUF,
+					       sizeof(dmabuf_cmsg),
+					       &dmabuf_cmsg);
+				if (err || msg->msg_flags & MSG_CTRUNC) {
+					msg->msg_flags &= ~MSG_CTRUNC;
+					if (!err)
+						err = -ETOOSMALL;
+					goto out;
+				}
+
+				atomic_long_inc(&niov->pp_ref_count);
+				tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag);
+
+				sent += copy;
+
+				if (remaining_len == 0)
+					goto out;
+			}
+			start = end;
+		}
+
+		tcp_xa_pool_commit(sk, &tcp_xa_pool);
+		if (!remaining_len)
+			goto out;
+
+		/* if remaining_len is not satisfied yet, we need to go to the
+		 * next frag in the frag_list to satisfy remaining_len.
+		 */
+		skb = skb_shinfo(skb)->frag_list ?: skb->next;
+
+		offset = offset - start;
+	} while (skb);
+
+	if (remaining_len) {
+		err = -EFAULT;
+		goto out;
+	}
+...
+out:
+	tcp_xa_pool_commit(sk, &tcp_xa_pool);
+	if (!sent)
+		sent = err;
+
+	return sent;
+}
+
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index de0c8f43448ab..57e48b75ac02a 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -79,6 +79,7 @@
 #include <linux/seq_file.h>
 #include <linux/inetdevice.h>
 #include <linux/btf_ids.h>
+#include <linux/skbuff_ref.h>
 
 #include <crypto/hash.h>
 #include <linux/scatterlist.h>
@@ -2503,6 +2504,15 @@ static void tcp_md5sig_info_free_rcu(struct rcu_head *head)
 void tcp_v4_destroy_sock(struct sock *sk)
 {
 	struct tcp_sock *tp = tcp_sk(sk);
+	__maybe_unused unsigned long index;
+	__maybe_unused void *netmem;
+
+#ifdef CONFIG_PAGE_POOL
+	xa_for_each(&sk->sk_user_frags, index, netmem)
+		WARN_ON_ONCE(!napi_pp_put_page((__force netmem_ref)netmem));
+#endif
+
+	xa_destroy(&sk->sk_user_frags);
 
 	trace_tcp_destroy_sock(sk);
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index bc67f6b9efae4..5d563312efe14 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -624,6 +624,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk,
 	__TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS);
 
+	xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1);
+
 	return newsk;
 }
 EXPORT_SYMBOL(tcp_create_openreq_child);