Re: [PATCH 6/8] smb: smbdirect: add in kernel only support for IPPROTO_SMBDIRECT

From: Kuniyuki Iwashima

Date: Tue Apr 07 2026 - 21:05:51 EST


On Tue, Apr 7, 2026 at 7:47 AM Stefan Metzmacher <metze@xxxxxxxxx> wrote:
>
> Userspace callers of socket() still get -EPROTONOSUPPORT,
> so for now we can be sure we only have in-kernel callers,
> cifs.ko and ksmbd.ko. This makes it possible to relax the
> constraints that generic stream socket consumers would
> otherwise assume.
>
> There's a prototype for userspace sockets on top of
> this and there's working userspace code for Samba as
> client and server, so this is just the first step,
> but a very important one.
>
> The SMBDIRECT protocol is defined in [MS-SMBD] by Microsoft.
> It is used as a wrapper around RDMA in order to provide a transport for SMB3,
> but Microsoft also uses it as a transport for other protocols.
>
> SMBDIRECT works over Infiniband, RoCE and iWarp.
> RoCEv2 is based on IP/UDP and iWarp is based on IP/TCP,
> so these use IP addresses natively.
> Infiniband and RoCEv1 require IPoIB in order to be used for
> SMBDIRECT.
>
> So instead of adding a PF_SMBDIRECT, which would only use AF_INET[6],
> we use IPPROTO_SMBDIRECT. This uses a number not allocated
> by IANA, as it would never appear in an IP header.

Overall I don't see the upside of reusing the AF_INET/AF_INET6 infra.
It just adds an unnecessary sk->sk_prot layer, which could simply
be implemented as sock->ops.

It seems only inet_getname() is the valid user of inet_sock.


>
> This is similar to IPPROTO_SMC, IPPROTO_MPTCP and IPPROTO_QUIC,
> which are Linux-specific values for the socket() syscall.

SMBDIRECT seems more like SMC than like MPTCP or QUIC.

SMC was implemented with AF_SMC first, and IPPROTO_SMC
was added later just to make it easy to hook userspace and
silently convert TCP sockets to AF_SMC sockets.

But that is not the case for SMBDIRECT, because the socket is
only created from kernel space.


>
> socket(AF_INET, SOCK_STREAM, IPPROTO_SMBDIRECT);
> socket(AF_INET6, SOCK_STREAM, IPPROTO_SMBDIRECT);
>
> This will allow the existing smbdirect code used by
> cifs.ko and ksmbd.ko to be moved behind the socket layer,

Reusing AF_INET/AF_INET6 is not related to this statement as long
as the upper layer uses the in-kernel socket API (sock_XXX()/kernel_XXX())
instead of calling sk->sk_prot->XXX() directly.



> so that there's less special handling. Only sock_sendmsg() and
> sock_recvmsg() are used, so that the main stream handling
> is done the same way for tcp, smbdirect and later also quic.
>
> The special RDMA read/write handling will still be done via direct
> function calls, as is currently done for the in-kernel
> consumers.
>
> For now the core smbdirect code still supports both
> modes: direct calls and indirect calls via the socket layer.
> The core code uses if (sc->sk.sk_family) as the indication
> for the new socket mode. Once cifs.ko and ksmbd.ko
> are converted we can slowly remove the old mode,
> but I'll defer that to a future patchset.
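To make the dual-mode idea concrete, here is a minimal single-threaded userspace model of the conditional locking that smbdirect_socket_sk_lock()/smbdirect_socket_sk_unlock() implement in this patch: the helpers only act when sk_family is non-zero, and callers drop the lock around a blocking wait and re-take it afterwards. All example_* names are hypothetical, and the "owned" flag only models lock_sock()/release_sock() ownership; it is not a real lock.

```c
#include <assert.h>
#include <stdbool.h>

/*
 * Hypothetical model: sk_family != 0 means the smbdirect socket is
 * embedded in a struct sock (new socket mode); 0 means the legacy
 * direct-call mode, where the helpers do nothing.
 */
struct example_sc {
	int sk_family;   /* 0 = legacy mode, non-zero = socket mode */
	bool owned;      /* models sock_owned_by_me() */
	int lock_calls;  /* number of modeled lock_sock() calls */
};

static void example_sk_lock(struct example_sc *sc)
{
	if (sc->sk_family) {
		assert(!sc->owned);  /* like sock_not_owned_by_me() */
		sc->owned = true;    /* like lock_sock() */
		sc->lock_calls++;
	}
}

static void example_sk_unlock(struct example_sc *sc)
{
	if (sc->sk_family) {
		assert(sc->owned);   /* like sock_owned_by_me() */
		sc->owned = false;   /* like release_sock() */
	}
}

/*
 * The pattern wrapped around wait_event*() in this patch: drop the
 * socket lock before sleeping, re-take it afterwards, then re-check
 * state. *ready stands in for the wait condition; a real caller would
 * sleep until it becomes true instead of reading it once.
 */
static int example_wait_unlocked(struct example_sc *sc, const bool *ready)
{
	bool ok;

	example_sk_unlock(sc);
	ok = *ready;             /* condition evaluated without the lock held */
	example_sk_lock(sc);
	return ok ? 0 : -1;
}
```

In legacy mode every helper is a no-op, which is what lets the same core code serve both cifs.ko/ksmbd.ko direct callers and socket-layer callers.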
>
> There's still a way to go in order to make this
> as generic as tcp and quic, e.g. adding MSG_SPLICE_PAGES support or
> splice_read/read_sock/read_skb.
>
> But it's a good start, which will make changes
> much easier.
>
> Cc: Steve French <smfrench@xxxxxxxxx>
> Cc: Tom Talpey <tom@xxxxxxxxxx>
> Cc: Long Li <longli@xxxxxxxxxxxxx>
> Cc: Namjae Jeon <linkinjeon@xxxxxxxxxx>
> Cc: David Howells <dhowells@xxxxxxxxxx>
> Cc: Henrique Carvalho <henrique.carvalho@xxxxxxxx>
> Cc: linux-cifs@xxxxxxxxxxxxxxx
> Cc: samba-technical@xxxxxxxxxxxxxxx
> Cc: David S. Miller <davem@xxxxxxxxxxxxx>
> Cc: Eric Dumazet <edumazet@xxxxxxxxxx>
> Cc: Jakub Kicinski <kuba@xxxxxxxxxx>
> Cc: Paolo Abeni <pabeni@xxxxxxxxxx>
> Cc: Simon Horman <horms@xxxxxxxxxx>
> Cc: Kuniyuki Iwashima <kuniyu@xxxxxxxxxx>
> Cc: Willem de Bruijn <willemb@xxxxxxxxxx>
> Cc: netdev@xxxxxxxxxxxxxxx
> Cc: Xin Long <lucien.xin@xxxxxxxxx>
> Cc: quic@xxxxxxxxxxxxxxx
> Cc: linux-rdma@xxxxxxxxxxxxxxx
> Cc: linux-kernel@xxxxxxxxxxxxxxx
> Signed-off-by: Stefan Metzmacher <metze@xxxxxxxxx>
> ---
> fs/smb/common/smbdirect/Makefile | 1 +
> fs/smb/common/smbdirect/smbdirect.h | 62 +
> fs/smb/common/smbdirect/smbdirect_accept.c | 14 +-
> .../common/smbdirect/smbdirect_connection.c | 58 +
> fs/smb/common/smbdirect/smbdirect_devices.c | 2 +-
> fs/smb/common/smbdirect/smbdirect_internal.h | 59 +-
> fs/smb/common/smbdirect/smbdirect_listen.c | 49 +-
> fs/smb/common/smbdirect/smbdirect_main.c | 45 +
> fs/smb/common/smbdirect/smbdirect_mr.c | 10 +
> fs/smb/common/smbdirect/smbdirect_proto.c | 1549 +++++++++++++++++
> fs/smb/common/smbdirect/smbdirect_public.h | 3 +
> fs/smb/common/smbdirect/smbdirect_rw.c | 29 +-
> fs/smb/common/smbdirect/smbdirect_socket.c | 147 ++
> fs/smb/common/smbdirect/smbdirect_socket.h | 26 +-
> 14 files changed, 2039 insertions(+), 15 deletions(-)
> create mode 100644 fs/smb/common/smbdirect/smbdirect_proto.c
>
> diff --git a/fs/smb/common/smbdirect/Makefile b/fs/smb/common/smbdirect/Makefile
> index 423f533e1002..fcff485d7c45 100644
> --- a/fs/smb/common/smbdirect/Makefile
> +++ b/fs/smb/common/smbdirect/Makefile
> @@ -10,6 +10,7 @@ smbdirect-y := \
> smbdirect_connection.o \
> smbdirect_mr.o \
> smbdirect_rw.o \
> + smbdirect_proto.o \
> smbdirect_debug.o \
> smbdirect_connect.o \
> smbdirect_listen.o \
> diff --git a/fs/smb/common/smbdirect/smbdirect.h b/fs/smb/common/smbdirect/smbdirect.h
> index bbab5f7f7cc9..cf3d4957f94c 100644
> --- a/fs/smb/common/smbdirect/smbdirect.h
> +++ b/fs/smb/common/smbdirect/smbdirect.h
> @@ -6,7 +6,10 @@
> #ifndef __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_H__
> #define __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_H__
>
> +#include <linux/stddef.h>
> #include <linux/types.h>
> +#include <linux/socket.h>
> +#include <asm/ioctls.h>
>
> /* SMB-DIRECT buffer descriptor V1 structure [MS-SMBD] 2.2.3.1 */
> struct smbdirect_buffer_descriptor_v1 {
> @@ -49,4 +52,63 @@ struct smbdirect_socket_parameters {
> SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB | \
> SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
>
> +enum {
> + __SMBDIRECT_BUFFER_REMOTE_INVALIDATE = 0x20,
> +};
> +
> +struct smbdirect_cmsg_buffer {
> + uint8_t msg_control[CMSG_SPACE(24)];
> +};
> +
> +static __always_inline
> +void __smbdirect_cmsg_prepare(struct msghdr *msg,
> + struct smbdirect_cmsg_buffer *cbuffer,
> + int cmsg_type,
> + const void *payload,
> + size_t payloadlen)
> +{
> + size_t cmsg_space = CMSG_SPACE(payloadlen);
> + size_t cmsg_len = CMSG_LEN(payloadlen);
> + struct cmsghdr *cmsg = NULL;
> + void *dataptr = NULL;
> +
> + BUILD_BUG_ON(cmsg_space > sizeof(cbuffer->msg_control));
> +
> + memset(cbuffer, 0, sizeof(*cbuffer));
> +
> + msg->msg_control = cbuffer->msg_control;
> + msg->msg_controllen = cmsg_space;
> +
> + cmsg = CMSG_FIRSTHDR(msg);
> + cmsg->cmsg_level = SOL_SMBDIRECT;
> + cmsg->cmsg_type = cmsg_type;
> + cmsg->cmsg_len = cmsg_len;
> + dataptr = CMSG_DATA(cmsg);
> + memcpy(dataptr, payload, payloadlen);
> + msg->msg_controllen = cmsg->cmsg_len;
> +}
> +
> +struct smbdirect_buffer_remote_invalidate_args {
> + __u32 remote_token;
> +} __packed;
> +#define SMBDIRECT_BUFFER_REMOTE_INVALIDATE_CMSG_TYPE \
> + _IOW('S', __SMBDIRECT_BUFFER_REMOTE_INVALIDATE, \
> + struct smbdirect_buffer_remote_invalidate_args)
> +
> +static __always_inline
> +void smbdirect_buffer_remote_invalidate_cmsg_prepare(struct msghdr *msg,
> + struct smbdirect_cmsg_buffer *cbuffer,
> + const __u32 *remote_token)
> +{
> + if (remote_token) {
> + struct smbdirect_buffer_remote_invalidate_args args = {
> + .remote_token = *remote_token,
> + };
> +
> + __smbdirect_cmsg_prepare(msg, cbuffer,
> + SMBDIRECT_BUFFER_REMOTE_INVALIDATE_CMSG_TYPE,
> + &args, sizeof(args));
> + }
> +}
> +
> #endif /* __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_H__ */
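A minimal userspace sketch of how such a control message could be built and parsed back, using only the standard CMSG_* macros. It assumes a placeholder value for SOL_SMBDIRECT (the real value is not shown in this excerpt); the example_* names are hypothetical, and the type encoding mirrors SMBDIRECT_BUFFER_REMOTE_INVALIDATE_CMSG_TYPE above.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>

/* Hypothetical placeholder; the real SOL_SMBDIRECT is not in this excerpt. */
#define EXAMPLE_SOL_SMBDIRECT 300

struct example_remote_invalidate_args {
	uint32_t remote_token;
};

/* Same encoding as the patch: _IOW('S', 0x20, <4-byte args struct>). */
#define EXAMPLE_REMOTE_INVALIDATE_CMSG_TYPE \
	((int)_IOW('S', 0x20, struct example_remote_invalidate_args))

/* The union forces struct cmsghdr alignment (idiom from cmsg(3)). */
union example_cmsg_buffer {
	unsigned char buf[CMSG_SPACE(24)];
	struct cmsghdr align;
};

static void example_cmsg_prepare(struct msghdr *msg, union example_cmsg_buffer *cbuf,
				 int cmsg_type, const void *payload, size_t payloadlen)
{
	struct cmsghdr *cmsg;

	/* Runtime stand-in for the patch's BUILD_BUG_ON(). */
	assert(CMSG_SPACE(payloadlen) <= sizeof(cbuf->buf));

	memset(cbuf, 0, sizeof(*cbuf));
	msg->msg_control = cbuf->buf;
	msg->msg_controllen = CMSG_SPACE(payloadlen);

	cmsg = CMSG_FIRSTHDR(msg);
	assert(cmsg != NULL);
	cmsg->cmsg_level = EXAMPLE_SOL_SMBDIRECT;
	cmsg->cmsg_type = cmsg_type;
	cmsg->cmsg_len = CMSG_LEN(payloadlen);
	memcpy(CMSG_DATA(cmsg), payload, payloadlen);
	/* Mirrors the patch, which trims msg_controllen to cmsg_len. */
	msg->msg_controllen = cmsg->cmsg_len;
}

static void example_remote_invalidate_cmsg_prepare(struct msghdr *msg,
						   union example_cmsg_buffer *cbuf,
						   const uint32_t *remote_token)
{
	/* No token: leave msg untouched, so no control message is attached. */
	if (remote_token) {
		struct example_remote_invalidate_args args = {
			.remote_token = *remote_token,
		};

		example_cmsg_prepare(msg, cbuf, EXAMPLE_REMOTE_INVALIDATE_CMSG_TYPE,
				     &args, sizeof(args));
	}
}

/* Receiving side: returns 0 and fills *token if the cmsg is present. */
static int example_remote_invalidate_cmsg_parse(struct msghdr *msg, uint32_t *token)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) {
		struct example_remote_invalidate_args args;

		if (cmsg->cmsg_level != EXAMPLE_SOL_SMBDIRECT ||
		    cmsg->cmsg_type != EXAMPLE_REMOTE_INVALIDATE_CMSG_TYPE)
			continue;
		if (cmsg->cmsg_len != CMSG_LEN(sizeof(args)))
			return -1;
		memcpy(&args, CMSG_DATA(cmsg), sizeof(args));
		*token = args.remote_token;
		return 0;
	}
	return -1;
}
```

Unlike struct smbdirect_cmsg_buffer in the patch, the sketch wraps the byte array in a union with struct cmsghdr, the idiom recommended in cmsg(3), so the buffer is guaranteed to be suitably aligned for CMSG_FIRSTHDR().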
> diff --git a/fs/smb/common/smbdirect/smbdirect_accept.c b/fs/smb/common/smbdirect/smbdirect_accept.c
> index d6d5e6a3f5de..6d7d869cdbc3 100644
> --- a/fs/smb/common/smbdirect/smbdirect_accept.c
> +++ b/fs/smb/common/smbdirect/smbdirect_accept.c
> @@ -6,7 +6,6 @@
> */
>
> #include "smbdirect_internal.h"
> -#include <net/sock.h>
> #include "../../common/smb2status.h"
>
> static int smbdirect_accept_rdma_event_handler(struct rdma_cm_id *id,
> @@ -460,6 +459,12 @@ static void smbdirect_accept_negotiate_recv_work(struct work_struct *work)
> spin_lock_irqsave(&lsc->listen.lock, flags);
> list_del(&sc->accept.list);
> list_add_tail(&sc->accept.list, &lsc->listen.ready);
> + if (lsc->sk.sk_family) {
> + struct sock *lsk = &lsc->sk;
> +
> + if (!sock_flag(lsk, SOCK_DEAD) && lsk->sk_socket)
> + lsk->sk_data_ready(lsk);
> + }
> wake_up(&lsc->listen.wait_queue);
> spin_unlock_irqrestore(&lsc->listen.lock, flags);
>
> @@ -774,11 +779,13 @@ static long smbdirect_socket_wait_for_accept(struct smbdirect_socket *lsc, long
> {
> long ret;
>
> + smbdirect_socket_sk_unlock(lsc);
> ret = wait_event_interruptible_timeout(lsc->listen.wait_queue,
> !list_empty_careful(&lsc->listen.ready) ||
> lsc->status != SMBDIRECT_SOCKET_LISTENING ||
> lsc->first_error,
> timeo);
> + smbdirect_socket_sk_lock(lsc);
> if (lsc->status != SMBDIRECT_SOCKET_LISTENING)
> return -EINVAL;
> if (lsc->first_error)
> @@ -850,6 +857,11 @@ struct smbdirect_socket *smbdirect_socket_accept(struct smbdirect_socket *lsc,
> * order to grant credits to the peer.
> */
> nsc->status = SMBDIRECT_SOCKET_CONNECTED;
> + if (nsc->sk.sk_family) {
> + struct sock *nsk = &nsc->sk;
> +
> + inet_sk_set_state(nsk, TCP_ESTABLISHED);
> + }
> smbdirect_accept_negotiate_finish(nsc, 0);
>
> return nsc;
> diff --git a/fs/smb/common/smbdirect/smbdirect_connection.c b/fs/smb/common/smbdirect/smbdirect_connection.c
> index 1e946f78e935..2c426aefd16d 100644
> --- a/fs/smb/common/smbdirect/smbdirect_connection.c
> +++ b/fs/smb/common/smbdirect/smbdirect_connection.c
> @@ -153,6 +153,15 @@ void smbdirect_connection_rdma_established(struct smbdirect_socket *sc)
>
> sc->rdma.cm_id->event_handler = smbdirect_connection_rdma_event_handler;
> sc->rdma.expected_event = RDMA_CM_EVENT_DISCONNECTED;
> +
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + smbdirect_socket_sync_saddr_to_sk(sc, NULL);
> + smbdirect_socket_sync_daddr_to_sk(sc);
> +
> + inet_sk_set_state(sk, TCP_SYN_RECV);
> + }
> }
>
> void smbdirect_connection_negotiation_done(struct smbdirect_socket *sc)
> @@ -189,6 +198,13 @@ void smbdirect_connection_negotiation_done(struct smbdirect_socket *sc)
> smbdirect_socket_status_string(sc->status),
> SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
> sc->status = SMBDIRECT_SOCKET_CONNECTED;
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + inet_sk_set_state(sk, TCP_ESTABLISHED);
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_socket->state = SS_CONNECTED;
> + }
>
> /*
> * We need to setup the refill and send immediate work
> @@ -203,6 +219,13 @@ void smbdirect_connection_negotiation_done(struct smbdirect_socket *sc)
> &sc->rdma.cm_id->route.addr.src_addr,
> &sc->rdma.cm_id->route.addr.dst_addr);
>
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_state_change(sk);
> + }
> +
> wake_up(&sc->status_wait);
> }
>
> @@ -739,10 +762,12 @@ int smbdirect_connection_wait_for_connected(struct smbdirect_socket *sc)
> "waiting for connection: device: %.*s local: %pISpsfc remote: %pISpsfc\n",
> IB_DEVICE_NAME_MAX, devname, src, dst);
>
> + smbdirect_socket_sk_unlock(sc);
> ret = wait_event_interruptible_timeout(sc->status_wait,
> sc->status == SMBDIRECT_SOCKET_CONNECTED ||
> sc->first_error,
> msecs_to_jiffies(sp->negotiate_timeout_msec));
> + smbdirect_socket_sk_lock(sc);
> if (sc->rdma.cm_id) {
> /*
> * Maybe src and dev are updated in the meantime.
> @@ -954,6 +979,12 @@ int smbdirect_connection_send_batch_flush(struct smbdirect_socket *sc,
> atomic_add(batch->credit, &sc->send_io.bcredits.count);
> batch->credit = 0;
> wake_up(&sc->send_io.bcredits.wait_queue);
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_write_space(sk);
> + }
> }
>
> return ret;
> @@ -1091,6 +1122,8 @@ int smbdirect_connection_send_single_iter(struct smbdirect_socket *sc,
> u32 data_length = 0;
> int ret;
>
> + smbdirect_socket_sk_owned_by_me(sc);
> +
> if (WARN_ON_ONCE(flags))
> return -EINVAL; /* no flags support for now */
>
> @@ -1150,10 +1183,12 @@ int smbdirect_connection_send_single_iter(struct smbdirect_socket *sc,
> * wait until either the refill work or the peer
> * granted new credits
> */
> + smbdirect_socket_sk_unlock(sc);
> ret = wait_event_interruptible(sc->send_io.credits.wait_queue,
> atomic_read(&sc->send_io.credits.count) >= 1 ||
> atomic_read(&sc->recv_io.credits.available) >= 1 ||
> sc->status != SMBDIRECT_SOCKET_CONNECTED);
> + smbdirect_socket_sk_lock(sc);
> if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
> ret = -ENOTCONN;
> if (ret < 0)
> @@ -1268,9 +1303,11 @@ int smbdirect_connection_send_wait_zero_pending(struct smbdirect_socket *sc)
> * that means all the I/Os have been out and we are good to return
> */
>
> + smbdirect_socket_sk_unlock(sc);
> wait_event(sc->send_io.pending.zero_wait_queue,
> atomic_read(&sc->send_io.pending.count) == 0 ||
> sc->status != SMBDIRECT_SOCKET_CONNECTED);
> + smbdirect_socket_sk_lock(sc);
> if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
> smbdirect_log_write(sc, SMBDIRECT_LOG_ERR,
> "status=%s first_error=%1pe => %1pe\n",
> @@ -1297,6 +1334,8 @@ int smbdirect_connection_send_iter(struct smbdirect_socket *sc,
> int error = 0;
> __be32 hdr;
>
> + smbdirect_socket_sk_owned_by_me(sc);
> +
> if (WARN_ONCE(flags, "unexpected flags=0x%x\n", flags))
> return -EINVAL; /* no flags support for now */
>
> @@ -1448,7 +1487,9 @@ static void smbdirect_connection_send_immediate_work(struct work_struct *work)
> smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_INFO,
> "send an empty message\n");
> sc->statistics.send_empty++;
> + smbdirect_socket_sk_lock(sc);
> ret = smbdirect_connection_send_single_iter(sc, NULL, NULL, 0, 0);
> + smbdirect_socket_sk_unlock(sc);
> if (ret < 0) {
> smbdirect_log_write(sc, SMBDIRECT_LOG_ERR,
> "smbdirect_connection_send_single_iter ret=%1pe\n",
> @@ -1632,6 +1673,12 @@ void smbdirect_connection_recv_io_done(struct ib_cq *cq, struct ib_wc *wc)
> * If any sender is waiting for credits, unblock it
> */
> wake_up(&sc->send_io.credits.wait_queue);
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_write_space(sk);
> + }
> }
>
> /* Send an immediate response right away if requested */
> @@ -1652,6 +1699,12 @@ void smbdirect_connection_recv_io_done(struct ib_cq *cq, struct ib_wc *wc)
>
> smbdirect_connection_reassembly_append_recv_io(sc, recv_io, data_length);
> wake_up(&sc->recv_io.reassembly.wait_queue);
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_data_ready(sk);
> + }
> } else
> smbdirect_connection_put_recv_io(recv_io);
>
> @@ -1735,6 +1788,9 @@ int smbdirect_connection_recv_io_refill(struct smbdirect_socket *sc)
> /*
> * If the last send credit is waiting for credits
> * it can grant we need to wake it up
> + *
> + * This needs to wake up smbdirect_connection_send_single_iter()
> + * only, so we don't need sk->sk_write_space() here.
> */
> if (atomic_read(&sc->send_io.bcredits.count) == 0 &&
> atomic_read(&sc->send_io.credits.count) == 0)
> @@ -1922,9 +1978,11 @@ int smbdirect_connection_recvmsg(struct smbdirect_socket *sc,
>
> smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
> "wait_event on more data\n");
> + smbdirect_socket_sk_unlock(sc);
> ret = wait_event_interruptible(sc->recv_io.reassembly.wait_queue,
> sc->recv_io.reassembly.data_length >= size ||
> sc->status != SMBDIRECT_SOCKET_CONNECTED);
> + smbdirect_socket_sk_lock(sc);
> /* Don't return any data if interrupted */
> if (ret)
> return ret;
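The same guard, "not SOCK_DEAD and still attached to a struct socket", precedes every sk_data_ready()/sk_write_space()/sk_state_change() call added to smbdirect_connection.c above. A minimal userspace model of that guard, with hypothetical stand-in fields for sock_flag(sk, SOCK_DEAD) and sk->sk_socket:

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-ins for the struct sock fields the guard looks at. */
struct example_sock {
	int sk_family;                  /* 0 = legacy (no struct sock) mode */
	bool dead;                      /* models sock_flag(sk, SOCK_DEAD) */
	void *sk_socket;                /* models sk->sk_socket; NULL = orphaned */
	void (*sk_data_ready)(struct example_sock *sk);
	int data_ready_calls;
};

static void example_count_data_ready(struct example_sock *sk)
{
	sk->data_ready_calls++;
}

/* Invoke the callback only for a live sock still attached to a socket. */
static bool example_sk_notify_data_ready(struct example_sock *sk)
{
	if (!sk->sk_family)
		return false;           /* legacy mode: only wake_up() is used */
	if (sk->dead || sk->sk_socket == NULL)
		return false;           /* no socket left to notify */
	assert(sk->sk_data_ready != NULL);
	sk->sk_data_ready(sk);
	return true;
}
```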
> diff --git a/fs/smb/common/smbdirect/smbdirect_devices.c b/fs/smb/common/smbdirect/smbdirect_devices.c
> index aaab99e9c045..da0edc104e48 100644
> --- a/fs/smb/common/smbdirect/smbdirect_devices.c
> +++ b/fs/smb/common/smbdirect/smbdirect_devices.c
> @@ -257,7 +257,7 @@ __init int smbdirect_devices_init(void)
> return 0;
> }
>
> -__exit void smbdirect_devices_exit(void)
> +__cold void smbdirect_devices_exit(void)
> {
> struct smbdirect_device *sdev, *tmp;
>
> diff --git a/fs/smb/common/smbdirect/smbdirect_internal.h b/fs/smb/common/smbdirect/smbdirect_internal.h
> index 30a1b8643657..517ff0533032 100644
> --- a/fs/smb/common/smbdirect/smbdirect_internal.h
> +++ b/fs/smb/common/smbdirect/smbdirect_internal.h
> @@ -12,8 +12,6 @@
> #include "smbdirect_pdu.h"
> #include "smbdirect_public.h"
>
> -#include <linux/mutex.h>
> -
> struct smbdirect_module_state {
> struct mutex mutex;
>
> @@ -30,6 +28,8 @@ struct smbdirect_module_state {
> rwlock_t lock;
> struct list_head list;
> } devices;
> +
> + struct smbdirect_socket_parameters default_parameters;
> };
>
> extern struct smbdirect_module_state smbdirect_globals;
> @@ -46,10 +46,58 @@ struct smbdirect_device {
> char ib_name[IB_DEVICE_NAME_MAX];
> };
>
> +static __always_inline void smbdirect_socket_sk_owned_by_me(struct smbdirect_socket *sc)
> +{
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> + }
> +}
> +
> +static __always_inline void smbdirect_socket_sk_not_owned_by_me(struct smbdirect_socket *sc)
> +{
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> + }
> +}
> +
> +static __always_inline void smbdirect_socket_sk_lock(struct smbdirect_socket *sc)
> +{
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + lock_sock(sk);
> + }
> +}
> +
> +static __always_inline void smbdirect_socket_sk_unlock(struct smbdirect_socket *sc)
> +{
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + release_sock(sk);
> + }
> +}
> +
> int smbdirect_socket_init_new(struct net *net, struct smbdirect_socket *sc);
>
> int smbdirect_socket_init_accepting(struct rdma_cm_id *id, struct smbdirect_socket *sc);
>
> +int smbdirect_socket_sync_saddr_to_sk(struct smbdirect_socket *sc, bool *_is_any_addr);
> +
> +int smbdirect_socket_sync_daddr_to_sk(struct smbdirect_socket *sc);
> +
> void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc,
> const char *macro_name,
> unsigned int lvl,
> @@ -135,7 +183,12 @@ int smbdirect_accept_connect_request(struct smbdirect_socket *sc,
>
> void smbdirect_accept_negotiate_finish(struct smbdirect_socket *sc, u32 ntstatus);
>
> +void smbdirect_sk_reclassify(struct sock *sk);
> +
> __init int smbdirect_devices_init(void);
> -__exit void smbdirect_devices_exit(void);
> +__cold void smbdirect_devices_exit(void);
> +
> +__init int smbdirect_proto_init(void);
> +__exit void smbdirect_proto_exit(void);
>
> #endif /* __FS_SMB_COMMON_SMBDIRECT_INTERNAL_H__ */
> diff --git a/fs/smb/common/smbdirect/smbdirect_listen.c b/fs/smb/common/smbdirect/smbdirect_listen.c
> index 05c7902e7020..a6e08d82dc73 100644
> --- a/fs/smb/common/smbdirect/smbdirect_listen.c
> +++ b/fs/smb/common/smbdirect/smbdirect_listen.c
> @@ -74,6 +74,12 @@ int smbdirect_socket_listen(struct smbdirect_socket *sc, int backlog)
> */
> sc->listen.backlog = backlog;
>
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + inet_sk_set_state(sk, TCP_LISTEN);
> + }
> +
> if (sc->rdma.cm_id->device)
> smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
> "listening on addr: %pISpsfc dev: %.*s\n",
> @@ -209,6 +215,7 @@ static int smbdirect_listen_connect_request(struct smbdirect_socket *lsc,
> const struct rdma_cm_event *event)
> {
> const struct smbdirect_socket_parameters *lsp = &lsc->parameters;
> + struct sock *nsk = NULL;
> struct smbdirect_socket *nsc;
> unsigned long flags;
> size_t backlog = max_t(size_t, 1, lsc->listen.backlog);
> @@ -265,9 +272,39 @@ static int smbdirect_listen_connect_request(struct smbdirect_socket *lsc,
> return -EBUSY;
> }
>
> - ret = smbdirect_socket_create_accepting(new_id, &nsc);
> - if (ret)
> - goto socket_init_failed;
> + if (lsc->sk.sk_family) {
> + struct sock *lsk = &lsc->sk;
> +
> + ret = -ENOMEM;
> + nsk = sk_clone(lsk, lsk->sk_allocation, false);
> + if (!nsk)
> + goto sk_clone_failed;
> + /* sk_clone_lock() increments refcnt to 2; drop the extra. */
> + __sock_put(nsk);
> + /* sk_clone() already called sk_sockets_allocated_inc(sk); */
> + sock_prot_inuse_add(sock_net(nsk), nsk->sk_prot, 1);
> +
> + smbdirect_sk_reclassify(nsk);
> + inet_sk_set_state(nsk, TCP_SYN_RECV);
> + nsc = smbdirect_socket_from_sk(nsk);
> +
> + ret = smbdirect_socket_init_accepting(new_id, nsc);
> + if (ret)
> + goto socket_init_failed;
> +
> + /*
> + * Note that smbdirect_sock_accept() will set
> + * SOCK_CUSTOM_SOCKOPT once [__]inet_accept()
> + * called sk_set_socket() via sock_graft().
> + */
> + WARN_ON_ONCE(nsc->orig_sk_destruct != lsc->orig_sk_destruct);
> + WARN_ON_ONCE(nsk->sk_destruct != lsk->sk_destruct);
> + WARN_ON_ONCE(nsk->sk_ipv6only != lsk->sk_ipv6only);
> + } else {
> + ret = smbdirect_socket_create_accepting(new_id, &nsc);
> + if (ret)
> + goto socket_init_failed;
> + }
>
> nsc->logging = lsc->logging;
> ret = smbdirect_socket_set_initial_parameters(nsc, &lsc->parameters);
> @@ -302,7 +339,11 @@ static int smbdirect_listen_connect_request(struct smbdirect_socket *lsc,
> */
> nsc->ib.dev = NULL;
> nsc->rdma.cm_id = NULL;
> - smbdirect_socket_release(nsc);
> + if (!nsk)
> + smbdirect_socket_release(nsc);
> socket_init_failed:
> + if (nsk)
> + sk_free(nsk);
> +sk_clone_failed:
> return ret;
> }
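For reference, the TCP states this patch mirrors into sk_state via inet_sk_set_state(), summarized as a small lookup. The hook names are just labels for the patch functions named in the comments, and TCP_CLOSE as the starting state is an assumption, since it is not shown in this excerpt.

```c
#include <assert.h>

/* Stand-ins for the TCP states the patch sets via inet_sk_set_state(). */
enum example_tcp_state {
	EX_TCP_CLOSE,         /* assumed starting state; not shown in this excerpt */
	EX_TCP_LISTEN,
	EX_TCP_SYN_RECV,
	EX_TCP_ESTABLISHED,
};

/* Labels for the patch functions that set a state (socket mode only). */
enum example_hook {
	EX_HOOK_SOCKET_LISTEN,          /* smbdirect_socket_listen() */
	EX_HOOK_LISTEN_CONNECT_REQUEST, /* smbdirect_listen_connect_request(), cloned nsk */
	EX_HOOK_RDMA_ESTABLISHED,       /* smbdirect_connection_rdma_established() */
	EX_HOOK_NEGOTIATION_DONE,       /* smbdirect_connection_negotiation_done(); also
					 * sets sk_socket->state = SS_CONNECTED if attached */
	EX_HOOK_SOCKET_ACCEPT,          /* smbdirect_socket_accept() */
};

static enum example_tcp_state example_state_after(enum example_hook hook)
{
	switch (hook) {
	case EX_HOOK_SOCKET_LISTEN:
		return EX_TCP_LISTEN;
	case EX_HOOK_LISTEN_CONNECT_REQUEST:
	case EX_HOOK_RDMA_ESTABLISHED:
		return EX_TCP_SYN_RECV;
	case EX_HOOK_NEGOTIATION_DONE:
	case EX_HOOK_SOCKET_ACCEPT:
		return EX_TCP_ESTABLISHED;
	}
	assert(!"unknown hook");
	return EX_TCP_CLOSE;
}
```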
> diff --git a/fs/smb/common/smbdirect/smbdirect_main.c b/fs/smb/common/smbdirect/smbdirect_main.c
> index fe6e8d93c34c..ccbe979332af 100644
> --- a/fs/smb/common/smbdirect/smbdirect_main.c
> +++ b/fs/smb/common/smbdirect/smbdirect_main.c
> @@ -12,6 +12,7 @@ struct smbdirect_module_state smbdirect_globals = {
>
> static __init int smbdirect_module_init(void)
> {
> + struct smbdirect_socket_parameters *sp;
> int ret = -ENOMEM;
>
> pr_notice("subsystem loading...\n");
> @@ -73,10 +74,52 @@ static __init int smbdirect_module_init(void)
> if (ret)
> goto devices_init_failed;
>
> + /*
> + * Create the global default parameters
> + */
> + sp = &smbdirect_globals.default_parameters;
> + sp->resolve_addr_timeout_msec = 5 * 1000;
> + sp->resolve_route_timeout_msec = 5 * 1000;
> + sp->rdma_connect_timeout_msec = 5 * 1000;
> + sp->negotiate_timeout_msec = 120 * 1000;
> + sp->initiator_depth = 1; /* the server should change this */
> + sp->responder_resources = 1; /* the client should change this */
> + sp->recv_credit_max = 255;
> + sp->send_credit_target = 255;
> + sp->max_send_size = 1364;
> + /*
> + * The maximum fragmented upper-layer payload receive size supported
> + *
> + * Assume max_payload_per_credit is
> + * smbd_max_receive_size - 24 = 1340
> + *
> + * The maximum number would be
> + * smbd_receive_credit_max * max_payload_per_credit
> + *
> + * 1340 * 255 = 341700 (0x536C4)
> + *
> + * The minimum value from the spec is 131072 (0x20000)
> + *
> + * For now we use the logic we used before:
> + * (1364 * 255) / 2 = 173910 (0x2A756)
> + */
> + sp->max_fragmented_recv_size = (1364 * 255) / 2;
> + sp->max_recv_size = 1364;
> + sp->max_read_write_size = 0; /* the server should change this */
> + sp->max_frmr_depth = 0; /* the client should change this */
> + sp->keepalive_interval_msec = 120 * 1000;
> + sp->keepalive_timeout_msec = 5 * 1000;
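A worked check of the arithmetic in the max_fragmented_recv_size comment above, assuming the 24-byte per-message overhead that comment uses: the chosen default of 173910 lies between the 131072 minimum quoted from the spec and the 341700 credit-limited maximum. The EX_* and ex_* names are illustrative only.

```c
#include <assert.h>
#include <stdint.h>

#define EX_MAX_RECV_SIZE       1364u   /* sp->max_recv_size */
#define EX_RECV_CREDIT_MAX      255u   /* sp->recv_credit_max */
#define EX_PER_MSG_OVERHEAD      24u   /* overhead assumed in the comment */
#define EX_SPEC_MIN_FRAG_RECV 131072u  /* spec minimum quoted in the comment */

static_assert(EX_MAX_RECV_SIZE > EX_PER_MSG_OVERHEAD,
	      "receive size must exceed the per-message overhead");

static uint32_t ex_max_payload_per_credit(void)
{
	return EX_MAX_RECV_SIZE - EX_PER_MSG_OVERHEAD;           /* 1340 */
}

static uint32_t ex_theoretical_max_fragmented(void)
{
	return EX_RECV_CREDIT_MAX * ex_max_payload_per_credit(); /* 341700 */
}

static uint32_t ex_default_max_fragmented(void)
{
	return (EX_MAX_RECV_SIZE * EX_RECV_CREDIT_MAX) / 2;      /* 173910 */
}
```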
> +
> + ret = smbdirect_proto_init();
> + if (ret)
> + goto proto_init_failed;
> +
> mutex_unlock(&smbdirect_globals.mutex);
> pr_notice("subsystem loaded\n");
> return 0;
>
> +proto_init_failed:
> + smbdirect_devices_exit();
> devices_init_failed:
> destroy_workqueue(smbdirect_globals.workqueues.cleanup);
> alloc_cleanup_wq_failed:
> @@ -101,6 +144,8 @@ static __exit void smbdirect_module_exit(void)
> pr_notice("subsystem unloading...\n");
> mutex_lock(&smbdirect_globals.mutex);
>
> + smbdirect_proto_exit();
> +
> smbdirect_devices_exit();
>
> destroy_workqueue(smbdirect_globals.workqueues.accept);
> diff --git a/fs/smb/common/smbdirect/smbdirect_mr.c b/fs/smb/common/smbdirect/smbdirect_mr.c
> index fa9be8089925..86bb72ed10ae 100644
> --- a/fs/smb/common/smbdirect/smbdirect_mr.c
> +++ b/fs/smb/common/smbdirect/smbdirect_mr.c
> @@ -167,9 +167,11 @@ smbdirect_connection_get_mr_io(struct smbdirect_socket *sc)
> int ret;
>
> again:
> + smbdirect_socket_sk_unlock(sc);
> ret = wait_event_interruptible(sc->mr_io.ready.wait_queue,
> atomic_read(&sc->mr_io.ready.count) ||
> sc->status != SMBDIRECT_SOCKET_CONNECTED);
> + smbdirect_socket_sk_lock(sc);
> if (ret) {
> smbdirect_log_rdma_mr(sc, SMBDIRECT_LOG_ERR,
> "wait_event_interruptible ret=%d (%1pe)\n",
> @@ -281,7 +283,9 @@ smbdirect_connection_register_mr_io(struct smbdirect_socket *sc,
> return NULL;
> }
>
> + smbdirect_socket_sk_lock(sc);
> mr = smbdirect_connection_get_mr_io(sc);
> + smbdirect_socket_sk_unlock(sc);
> if (!mr) {
> smbdirect_log_rdma_mr(sc, SMBDIRECT_LOG_ERR,
> "smbdirect_connection_get_mr_io returning NULL\n");
> @@ -415,6 +419,12 @@ void smbdirect_connection_deregister_mr_io(struct smbdirect_mr_io *mr)
> if (mr->state == SMBDIRECT_MR_DISABLED)
> goto put_kref;
>
> + /*
> + * We are protected by mr->mutex
> + * without lock_sock().
> + */
> + smbdirect_socket_sk_not_owned_by_me(sc);
> +
> if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
> smbdirect_mr_io_disable_locked(mr);
> goto put_kref;
> diff --git a/fs/smb/common/smbdirect/smbdirect_proto.c b/fs/smb/common/smbdirect/smbdirect_proto.c
> new file mode 100644
> index 000000000000..1a832d52eb89
> --- /dev/null
> +++ b/fs/smb/common/smbdirect/smbdirect_proto.c
> @@ -0,0 +1,1549 @@
> +// SPDX-License-Identifier: GPL-2.0-or-later
> +/*
> + * Copyright (c) 2025 Stefan Metzmacher
> + */
> +
> +#include "smbdirect_internal.h"
> +#include <net/protocol.h>
> +#include <net/inet_common.h>
> +#include <linux/bpf-cgroup.h>
> +#include <linux/errname.h>
> +
> +#define SMBDIRECT_FN_GENERIC(__sk, __fmt, __args...) do { \
> + struct smbdirect_socket *__sc = smbdirect_socket_from_sk(__sk); \
> + __smbdirect_log_generic(__sc, SMBDIRECT_LOG_INFO, SMBDIRECT_LOG_SK, \
> + __fmt " sc=%p %s first_error=%1pe kern=%u locked=%u refs=%u dead=%u mrefs=%u\n", \
> + ##__args, __sc, \
> + smbdirect_socket_status_string(__sc->status), \
> + SMBDIRECT_DEBUG_ERR_PTR(__sc->first_error), \
> + (__sk)->sk_kern_sock, \
> + sock_owned_by_user_nocheck(__sk), \
> + refcount_read(&((__sk)->sk_refcnt)), \
> + sock_flag(__sk, SOCK_DEAD), \
> + module_refcount(THIS_MODULE)); \
> +} while (0)
> +
> +#define SMBDIRECT_FN_COMMENT(__sk, __comment) \
> + SMBDIRECT_FN_GENERIC(__sk, "%s with", __comment)
> +
> +#define SMBDIRECT_FN_CALLED(__sk) \
> + SMBDIRECT_FN_GENERIC(__sk, "Called for")
> +
> +#define SMBDIRECT_FN_RETURN_VOID(__sk) \
> + SMBDIRECT_FN_GENERIC(__sk, "Returning for")
> +
> +#define SMBDIRECT_FN_RETURN_POLL(__sk, __mask) \
> + SMBDIRECT_FN_GENERIC(__sk, "Returning mask=0x%x for", __mask)
> +
> +#define SMBDIRECT_FN_RETURN_INT(__sk, __ret) do { \
> + bool __is_err = IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(__ret)); \
> + SMBDIRECT_FN_GENERIC(__sk, "Returning ret=%d%s%s%s for", \
> + (__ret), \
> + __is_err ? " (" : "", \
> + __is_err ? errname(__ret) : "", \
> + __is_err ? ")" : ""); \
> +} while (0)
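A userspace sketch of the formatting SMBDIRECT_FN_RETURN_INT() produces, assuming SMBDIRECT_DEBUG_ERR_PTR() behaves like ERR_PTR(): a " (NAME)" suffix is appended only when the value falls in the IS_ERR() range [-MAX_ERRNO, -1]. ex_errname() is a hypothetical, deliberately tiny stand-in for the kernel's errname() and does not reproduce its exact output.

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>

/* Userspace stand-in for the kernel's MAX_ERRNO bound used by IS_ERR(). */
#define EX_MAX_ERRNO 4095

/* Hypothetical mini errname(): maps a few negative errno values to names. */
static const char *ex_errname(int err)
{
	switch (-err) {
	case EINVAL:          return "EINVAL";
	case ENOMEM:          return "ENOMEM";
	case ENOTCONN:        return "ENOTCONN";
	case EPROTONOSUPPORT: return "EPROTONOSUPPORT";
	default:              return "E?";
	}
}

/* Print "ret=%d" and, only for IS_ERR()-range values, append " (NAME)". */
static int ex_format_ret(char *buf, size_t len, int ret)
{
	bool is_err = ret < 0 && ret >= -EX_MAX_ERRNO;

	return snprintf(buf, len, "ret=%d%s%s%s", ret,
			is_err ? " (" : "",
			is_err ? ex_errname(ret) : "",
			is_err ? ")" : "");
}
```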
> +
> +static bool smbdirect_sk_logging_needed(struct smbdirect_socket *sc,
> + void *private_ptr,
> + unsigned int lvl,
> + unsigned int cls)
> +{
> + /*
> + * Only errors by default.
> + */
> + if (lvl <= SMBDIRECT_LOG_ERR)
> + return true;
> + return false;
> +}
> +
> +static void smbdirect_sk_logging_vaprintf(struct smbdirect_socket *sc,
> + const char *func,
> + unsigned int line,
> + void *private_ptr,
> + unsigned int lvl,
> + unsigned int cls,
> + struct va_format *vaf)
> +{
> + if (lvl <= SMBDIRECT_LOG_ERR)
> + pr_err("%s:%u %pV", func, line, vaf);
> + else
> + pr_info("%s:%u %pV", func, line, vaf);
> +}
> +
> +void smbdirect_sk_reclassify(struct sock *sk)
> +{
> +#ifdef CONFIG_DEBUG_LOCK_ALLOC
> + static struct lock_class_key sk_key[2];
> + static struct lock_class_key slock_key[2];
> +
> + if (WARN_ON_ONCE(!sock_allow_reclassification(sk)))
> + return;
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + /*
> + * Before we reset the owner we
> + * need to drop the reference of the
> + * existing module. This is not
> + * really relevant for AF_INET:
> + * as that is always builtin,
> + * there's no potential leak
> + * of module references. We do it
> + * mainly in order to match the
> + * AF_INET6 case.
> + */
> + sk_owner_put(sk);
> + sk_owner_clear(sk);
> +
> + sock_lock_init_class_and_name(sk,
> + "slock-AF_INET-IPPROTO-SMBDIRECT",
> + &slock_key[0],
> + "sk_lock-AF_INET-IPPROTO-SMBDIRECT",
> + &sk_key[0]);
> +
> + /*
> + * Now that we reclassified the socket
> + * we're also the new sk_owner, but that's
> + * not needed as there's still a reference
> + * on sk->sk_prot->owner, which is dropped
> + * in sk_prot_free(). But in order to
> + * avoid module reference leaks to our
> + * own module we need to put and clear
> + * sk_owner, in order to allow callers
> + * to do their own reclassification.
> + */
> + sk_owner_put(sk);
> + sk_owner_clear(sk);
> + break;
> + case AF_INET6:
> + /*
> + * Before we reset the owner we
> + * need to drop the reference of the
> + * existing module.
> + *
> + * As we also use inet6_register_protosw()
> + * and other symbols from a possible
> + * ipv6.ko, we already have enough
> + * module references in order to avoid
> + * unloading of ipv6.ko, while smbdirect.ko
> + * is loaded.
> + *
> + * However when smbdirect.ko is unloaded
> + * we should not leak references in order
> + * to allow ipv6.ko to be unloaded as well.
> + */
> + sk_owner_put(sk);
> + sk_owner_clear(sk);
> +
> + sock_lock_init_class_and_name(sk,
> + "slock-AF_INET6-IPPROTO-SMBDIRECT",
> + &slock_key[1],
> + "sk_lock-AF_INET6-IPPROTO-SMBDIRECT",
> + &sk_key[1]);
> +
> + /*
> + * Now that we reclassified the socket
> + * we're also the new sk_owner, but that's
> + * not needed as there's still a reference
> + * on sk->sk_prot->owner, which is dropped
> + * in sk_prot_free(). But in order to
> + * avoid module reference leaks to our
> + * own module we need to put and clear
> + * sk_owner, in order to allow callers
> + * to do their own reclassification.
> + */
> + sk_owner_put(sk);
> + sk_owner_clear(sk);
> + break;
> + default:
> + WARN_ON_ONCE(1);
> + }
> +#endif /* CONFIG_DEBUG_LOCK_ALLOC */
> +}
> +
> +static void smbdirect_sk_destruct(struct sock *sk)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> +
> + /*
> + * Called by sk_free()
> + */
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + if (WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_DESTROYED)) {
> + pr_err("Attempt to release SMBDIRECT socket in status %s sc %p\n",
> + smbdirect_socket_status_string(sc->status), sc);
> + SMBDIRECT_FN_RETURN_VOID(sk);
> + return;
> + }
> +
> + SMBDIRECT_FN_COMMENT(sk, "calling orig_sk_destruct");
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "sc[%p]->orig_sk_destruct[%ps]\n",
> + sc, sc->orig_sk_destruct);
> + sc->orig_sk_destruct(sk);
> + SMBDIRECT_FN_RETURN_VOID(sk);
> +}
> +
> +static int smbdirect_sk_init(struct sock *sk)
> +{
> + struct socket *sock = sk->sk_socket;
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + const struct smbdirect_socket_parameters *sp = &smbdirect_globals.default_parameters;
> + void (*orig_sk_destruct)(struct sock *sk);
> + int ret;
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + smbdirect_sk_reclassify(sk);
> +
> + smbdirect_socket_init(sc);
> + smbdirect_socket_set_logging(sc,
> + NULL,
> + smbdirect_sk_logging_needed,
> + smbdirect_sk_logging_vaprintf);
> +
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "Called for sk=%p family=%u protocol=%u type=%u\n",
> + sk, sk->sk_family, sk->sk_protocol, sk->sk_type);
> +
> + sk_sockets_allocated_inc(sk);
> + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
> +
> + orig_sk_destruct = sk->sk_destruct;
> + SMBDIRECT_FN_COMMENT(sk, "remembered orig_sk_destruct");
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "sc[%p]->orig_sk_destruct[%ps]\n",
> + sc, orig_sk_destruct);
> + sc->orig_sk_destruct = orig_sk_destruct;
> + sk->sk_destruct = smbdirect_sk_destruct;
> +
> + /*
> + * We want to handle all sockopts explicitly
> + * and only support what we really support.
> + */
> + set_bit(SOCK_CUSTOM_SOCKOPT, &sock->flags);
> + /*
> + * There are no legacy callers, so we are strict
> + * regarding ipv4 vs. ipv6.
> + */
> + sk->sk_ipv6only = true;
> +
> + /*
> + * No userspace sockets yet...
> + */
> + if (!sk->sk_kern_sock) {
> + sc->first_error = -EPROTONOSUPPORT;
> + SMBDIRECT_FN_COMMENT(sk, "No userspace sockets");
> + return -EPROTONOSUPPORT;
> + }
> +
> + ret = smbdirect_socket_init_new(sock_net(sk), sc);
> + if (ret)
> + goto socket_init_failed;
> + /*
> + * smbdirect_socket_init_new() called smbdirect_socket_init() again,
> + * so we need to call smbdirect_socket_set_logging() again!
> + */
> + smbdirect_socket_set_logging(sc,
> + NULL,
> + smbdirect_sk_logging_needed,
> + smbdirect_sk_logging_vaprintf);
> +
> + WARN_ON_ONCE(sc->orig_sk_destruct != orig_sk_destruct);
> + WARN_ON_ONCE(sk->sk_destruct != smbdirect_sk_destruct);
> +
> + ret = smbdirect_socket_set_initial_parameters(sc, sp);
> + if (ret)
> + goto set_params_failed;
> +
> + ret = smbdirect_socket_set_kernel_settings(sc, IB_POLL_SOFTIRQ, sk->sk_allocation);
> + if (ret)
> + goto set_settings_failed;
> +
> + SMBDIRECT_FN_RETURN_INT(sk, 0);
> + return 0;
> +
> +set_settings_failed:
> +set_params_failed:
> +socket_init_failed:
> + sc->first_error = ret;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static void smbdirect_sk_destroy(struct sock *sk)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + /*
> + * For now do a sync disconnect/destroy
> + *
> + * SMBDIRECT_LOG_INFO is enough here
> + * as this is the typical case where
> + * we terminate the connection ourselves.
> + */
> + smbdirect_socket_schedule_cleanup_lvl(sc,
> + SMBDIRECT_LOG_INFO,
> + -ESHUTDOWN);
> + smbdirect_socket_destroy_sync(sc);
> +
> + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
> + sk_sockets_allocated_dec(sk);
> +
> + SMBDIRECT_FN_RETURN_VOID(sk);
> +}
> +
> +static int smbdirect_sk_hash(struct sock *sk)
> +{
> + SMBDIRECT_FN_CALLED(sk);
> + return 0;
> +}

It seems this was implemented just to fill in all the function
pointers of sk->sk_prot, but it looks unnecessary.

Same for the other NOP functions: unhash(), release_cb(), etc.
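Here is a userspace model of the point. struct demo_proto, demo_release() and the other names are hypothetical, not kernel APIs. Generic code that treats a hook as optional skips it when the pointer is NULL, so an ops table only needs the entries it really implements. In the kernel, release_sock() checks ->release_cb this way before calling it. Whether a particular caller tolerates a NULL entry still has to be verified per call site.

```c
#include <assert.h>
#include <stddef.h>

/* Userspace model only; all names here are hypothetical. */
struct demo_sock;

struct demo_proto {
	int  (*hash)(struct demo_sock *sk);        /* optional */
	void (*unhash)(struct demo_sock *sk);      /* optional */
	void (*release_cb)(struct demo_sock *sk);  /* optional */
};

struct demo_sock {
	const struct demo_proto *prot;
	int release_cb_calls;
};

/* Generic code calls an optional hook only if it is set. */
static void demo_release(struct demo_sock *sk)
{
	if (sk->prot->release_cb)
		sk->prot->release_cb(sk);
}

static void counting_release_cb(struct demo_sock *sk)
{
	sk->release_cb_calls++;
}

/* A table with no hooks at all, and one that fills in only what it uses. */
static const struct demo_proto minimal_proto = { 0 };
static const struct demo_proto counting_proto = {
	.release_cb = counting_release_cb,
};
```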


> +
> +static void smbdirect_sk_unhash(struct sock *sk)
> +{
> + SMBDIRECT_FN_CALLED(sk);
> +}
> +
> +static void smbdirect_sk_release_cb(struct sock *sk)
> +{
> + /*
> + * Called from release_sock()
> + */
> + SMBDIRECT_FN_CALLED(sk);
> +}
> +
> +static int smbdirect_sk_pre_bind(struct sock *sk,
> + struct sockaddr_unsized *uaddr,
> + int *addr_len,
> + u32 *flags,
> + u16 *port)
> +{
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + if (*addr_len < sizeof(uaddr->sa_family))
> + return -EINVAL;
> +
> + /* AF_UNSPEC is not allowed */
> + if (sk->sk_family != uaddr->sa_family)
> + return -EAFNOSUPPORT;
> +
> + /*
> + * BPF prog is run before any checks are done so that if the prog
> + * changes context in a wrong way it will be caught.
> + */
> + switch (sk->sk_family) {
> + case AF_INET:
> + if (*addr_len < sizeof(struct sockaddr_in))
> + return -EINVAL;
> +
> + *port = ntohs(((struct sockaddr_in *)uaddr)->sin_port);
> +
> + return BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, addr_len,
> + CGROUP_INET4_BIND,
> + flags);

Do you really need these BPF hooks?

It seems the SMB sockets can be created from a kthread
and tied to the root cgroup.


> + case AF_INET6:
> + /*
> + * We require a full struct sockaddr_in6 (28 bytes) instead of a
> + * minimal size of SIN6_LEN_RFC2133 (24 bytes), as we don't
> + * have any legacy callers in userspace and the
> + * rdma layer also expects that.
> + */
> + if (*addr_len < sizeof(struct sockaddr_in6))
> + return -EINVAL;
> +
> + *port = ntohs(((struct sockaddr_in6 *)uaddr)->sin6_port);
> +
> + return BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, addr_len,
> + CGROUP_INET6_BIND,
> + flags);
> + }
> +
> + return -EAFNOSUPPORT;
> +}
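The family and length checks above are repeated in smbdirect_sk_pre_connect() and smbdirect_sock_connect_locked(). The sketch below shows what they boil down to. It runs in userspace, uses a hypothetical helper name and leaves out the BPF cgroup hooks entirely.

```c
#include <assert.h>
#include <errno.h>
#include <netinet/in.h>
#include <sys/socket.h>

/* Userspace sketch; demo_check_addr() is a hypothetical name. */
static int demo_check_addr(int sk_family, const struct sockaddr *uaddr,
			   socklen_t addr_len)
{
	if (addr_len < sizeof(uaddr->sa_family))
		return -EINVAL;

	/* AF_UNSPEC and any other mismatching family are rejected. */
	if (uaddr->sa_family != sk_family)
		return -EAFNOSUPPORT;

	switch (sk_family) {
	case AF_INET:
		if (addr_len < sizeof(struct sockaddr_in))
			return -EINVAL;
		return 0;
	case AF_INET6:
		/* Full 28-byte sockaddr_in6, not the 24-byte RFC 2133 size. */
		if (addr_len < sizeof(struct sockaddr_in6))
			return -EINVAL;
		return 0;
	}
	return -EAFNOSUPPORT;
}
```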
> +
> +static int smbdirect_sk_do_bind(struct sock *sk,
> + struct sockaddr_unsized *uaddr,
> + const u32 flags,
> + const u16 port)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + bool is_any_addr = true;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + if (flags & BIND_WITH_LOCK)
> + sock_owned_by_me(sk);
> + else
> + sock_not_owned_by_me(sk);
> +
> + ret = smbdirect_socket_bind(sc, (struct sockaddr *)uaddr);
> + if (ret) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + ret = smbdirect_socket_sync_saddr_to_sk(sc, &is_any_addr);
> + if (ret) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + /* Make sure we are allowed to bind here. */
> + if (sk->sk_num && !(flags & BIND_FROM_BPF)) {
> + switch (sk->sk_family) {
> + case AF_INET:
> + ret = BPF_CGROUP_RUN_PROG_INET4_POST_BIND(sk);
> + if (ret) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> + break;
> +
> + case AF_INET6:
> + ret = BPF_CGROUP_RUN_PROG_INET6_POST_BIND(sk);
> + if (ret) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> + break;
> + }
> + }
> +
> + if (!is_any_addr)
> + sk->sk_userlocks |= SOCK_BINDADDR_LOCK;
> + if (port)
> + sk->sk_userlocks |= SOCK_BINDPORT_LOCK;

Can this socket be passed to SOCK_BINDPORT_LOCK users,
inet_bhash2_reset_saddr(), or inet_sk_rebuild_header()?


> +
> + ret = 0;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_bind(struct sock *sk, struct sockaddr_unsized *addr, int addr_len)
> +{
> + struct net *net = sock_net(sk);
> + u32 flags = BIND_WITH_LOCK;
> + u16 port = 0;
> + u16 check_port = 0;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + ret = smbdirect_sk_pre_bind(sk, addr, &addr_len, &flags, &port);
> + if (ret) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + /*
> + * treat the iwarp tcp port for
> + * smb (5445) as the main smb port (445)
> + * and only allow the bind if 445
> + * would be allowed.
> + */
> + if (port == 5445)
> + check_port = 445;
> + else
> + check_port = port;
> +
> + if (!(flags & BIND_NO_CAP_NET_BIND_SERVICE) &&
> + check_port && inet_port_requires_bind_service(net, check_port) &&
> + !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) {
> + ret = -EACCES;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (flags & BIND_WITH_LOCK)
> + lock_sock(sk);

Is connect() called without bind(), and could a BPF prog
call bpf_bind() for this socket?


> +
> + ret = smbdirect_sk_do_bind(sk, addr, flags, port);
> +
> + if (flags & BIND_WITH_LOCK)
> + release_sock(sk);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
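The 5445 -> 445 mapping above means that an iWARP listener on port 5445, which would normally be unprivileged, needs the same CAP_NET_BIND_SERVICE as port 445. The model below makes two assumptions. First, it takes inet_port_requires_bind_service() to reduce to `port < net.ipv4.ip_unprivileged_port_start` (1024 by default). Second, it ignores BIND_NO_CAP_NET_BIND_SERVICE. The helper name is hypothetical.

```c
#include <assert.h>
#include <stdbool.h>

/*
 * Userspace model of the privileged-port check in smbdirect_sk_bind().
 * prot_sock stands in for net.ipv4.ip_unprivileged_port_start.
 */
static bool demo_bind_needs_cap(unsigned short port, unsigned short prot_sock)
{
	/* The iWARP SMB port 5445 is checked as if it were 445. */
	unsigned short check_port = (port == 5445) ? 445 : port;

	/* Port 0 (kernel-chosen) never needs CAP_NET_BIND_SERVICE. */
	return check_port != 0 && check_port < prot_sock;
}
```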
> +
> +static struct sock *smbdirect_sk_accept(struct sock *lsk, struct proto_accept_arg *arg)
> +{
> + struct smbdirect_socket *lsc = smbdirect_socket_from_sk(lsk);
> + long timeo = sock_rcvtimeo(lsk, arg->flags & O_NONBLOCK);
> + struct smbdirect_socket *nsc;
> + struct sock *nsk;
> +
> + SMBDIRECT_FN_CALLED(lsk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(lsk);
> +
> + lock_sock(lsk);
> + nsc = smbdirect_socket_accept(lsc, timeo, arg);
> + release_sock(lsk);
> + if (!nsc) {
> + SMBDIRECT_FN_RETURN_INT(lsk, arg->err);
> + return NULL;
> + }
> + nsk = &nsc->sk;
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(nsk);

Looks redundant.


> +
> + SMBDIRECT_FN_RETURN_INT(lsk, 0);
> + return nsk;
> +}
> +
> +static int smbdirect_sk_pre_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len)
> +{
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + if (addr_len < sizeof(uaddr->sa_family))
> + return -EINVAL;
> +
> + if (sk->sk_family != uaddr->sa_family)
> + return -EAFNOSUPPORT;
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + if (addr_len < sizeof(struct sockaddr_in))
> + return -EINVAL;
> +
> + return BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr, &addr_len);
> + case AF_INET6:
> + /*
> + * We require a full struct sockaddr_in6 (28 bytes) instead of a
> + * minimal size of SIN6_LEN_RFC2133 (24 bytes), as we don't
> + * have any legacy callers in userspace and the
> + * rdma layer also expects that.
> + */
> + if (addr_len < sizeof(struct sockaddr_in6))
> + return -EINVAL;
> +
> + return BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr, &addr_len);
> + }
> +
> + return -EAFNOSUPPORT;
> +}
> +
> +static int smbdirect_sk_connect(struct sock *sk, struct sockaddr_unsized *addr, int addr_len)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + ret = smbdirect_connect(sc, (struct sockaddr *)addr);

Why is this called via sk->sk_prot instead of being called
directly from sock->ops->connect()?

Same for the other sk->sk_prot functions.
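The following userspace model shows what calling the transport directly from the socket op could look like, with a single ops table and no second layer. demo_socket, demo_proto_ops and demo_conn_connect() are hypothetical stand-ins for struct socket, struct proto_ops and smbdirect_connect(). The address is reduced to an int.

```c
#include <assert.h>
#include <errno.h>

/* Hypothetical userspace model; none of these are kernel APIs. */
struct demo_conn {
	int connected;
};

struct demo_socket;

struct demo_proto_ops {
	int (*connect)(struct demo_socket *sock, int addr);
};

struct demo_socket {
	const struct demo_proto_ops *ops;
	struct demo_conn conn;
};

static int demo_conn_connect(struct demo_conn *c, int addr)
{
	if (addr < 0)
		return -EINVAL;
	c->connected = 1;
	return 0;
}

/*
 * One indirection: the socket op calls the transport directly,
 * with no second table (like sk->sk_prot) in between.
 */
static int demo_sock_connect(struct demo_socket *sock, int addr)
{
	return demo_conn_connect(&sock->conn, addr);
}

static const struct demo_proto_ops demo_ops = {
	.connect = demo_sock_connect,
};
```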


> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_setsockopt(struct sock *sk, int level, int optname,
> + sockptr_t optval, unsigned int optlen)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + switch (level) {
> + default:
> + SMBDIRECT_FN_COMMENT(sk, "default");
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "level=%d optname=%d for sk=%p\n",
> + level, optname, sk);
> + ret = -EOPNOTSUPP;
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_getsockopt(struct sock *sk, int level, int optname,
> + char __user *optval, int __user *optlen)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + switch (level) {
> + default:
> + SMBDIRECT_FN_COMMENT(sk, "default");
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "level=%d optname=%d for sk=%p\n",
> + level, optname, sk);
> + ret = -EOPNOTSUPP;
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_ioctl(struct sock *sk, int cmd, int *karg)

Is there any in-kernel ioctl() user for this socket?


> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + switch (cmd) {
> + default:
> + SMBDIRECT_FN_COMMENT(sk, "default");
> + smbdirect_log_sk(sc, SMBDIRECT_LOG_INFO,
> + "cmd=%d (0x%x) for sk=%p\n",
> + cmd, cmd, sk);
> + ret = -ENOIOCTLCMD;
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static inline size_t smbdirect_cmsg_count(const struct msghdr *_msg,
> + int *first_sol_smbdirect_type)
> +{
> + struct msghdr *msg = (struct msghdr *)(uintptr_t)(const void *)_msg;
> + struct cmsghdr *cmsg = NULL;
> + size_t count = 0;
> +
> + if (first_sol_smbdirect_type != NULL)
> + *first_sol_smbdirect_type = -1;
> +
> + for (cmsg = CMSG_FIRSTHDR(msg);
> + cmsg != NULL;
> + cmsg = CMSG_NXTHDR(msg, cmsg)) {
> + count++;
> + if (cmsg->cmsg_level != SOL_SMBDIRECT)
> + continue;
> + if (first_sol_smbdirect_type != NULL) {
> + *first_sol_smbdirect_type = cmsg->cmsg_type;
> + first_sol_smbdirect_type = NULL;
> + }
> + }
> +
> + return count;
> +}
> +
> +static __always_inline
> +ssize_t __smbdirect_cmsg_extract(const struct msghdr *_msg,
> + int cmsg_type,
> + void *_payload,
> + size_t payloadmin,
> + size_t payloadmax)
> +{
> + struct msghdr *msg = (struct msghdr *)(uintptr_t)(const void *)_msg;
> + size_t cmsg_len_min = CMSG_LEN(payloadmin);
> + size_t cmsg_len_max = CMSG_LEN(payloadmax);
> + const size_t cmsg_len_hdr = CMSG_LEN(0);
> + uint8_t *payload = (uint8_t *)_payload;
> + struct cmsghdr *cmsg = NULL;
> + size_t payloadlen;
> +
> + BUILD_BUG_ON(cmsg_len_min > cmsg_len_max);
> + if (WARN_ON_ONCE(cmsg_len_min > cmsg_len_max))
> + return -EBADMSG;
> +
> + for (cmsg = CMSG_FIRSTHDR(msg);
> + cmsg != NULL;
> + cmsg = CMSG_NXTHDR(msg, cmsg)) {
> + if (cmsg->cmsg_level != SOL_SMBDIRECT)
> + continue;
> +
> + if (cmsg->cmsg_type != cmsg_type)
> + continue;
> +
> + if (cmsg->cmsg_len < cmsg_len_min)
> + return -EBADMSG;
> +
> + if (cmsg->cmsg_len > cmsg_len_max)
> + return -EMSGSIZE;
> +
> + payloadlen = cmsg->cmsg_len - cmsg_len_hdr;
> + if (payloadlen > 0)
> + memcpy(payload, CMSG_DATA(cmsg), payloadlen);
> + if (payloadlen < payloadmax)
> + memset(payload + payloadlen, 0, payloadmax - payloadlen);
> + return payloadlen;
> + }
> +
> + return -ENOMSG;
> +}
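The lookup done by __smbdirect_cmsg_extract() can be reproduced in userspace with the standard CMSG_* macros. DEMO_SOL is a hypothetical stand-in for SOL_SMBDIRECT, whose value is not shown here, and the helper name is also hypothetical. The sketch finds the first matching control message, bounds its payload between payloadmin and payloadmax, copies it and zero-fills the remainder.

```c
#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>

/* Hypothetical level value standing in for SOL_SMBDIRECT. */
#define DEMO_SOL 0x5344

static ssize_t demo_cmsg_extract(struct msghdr *msg, int level, int type,
				 void *payload, size_t payloadmin,
				 size_t payloadmax)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) {
		size_t len;

		if (cmsg->cmsg_level != level || cmsg->cmsg_type != type)
			continue;
		if (cmsg->cmsg_len < CMSG_LEN(payloadmin))
			return -EBADMSG;
		if (cmsg->cmsg_len > CMSG_LEN(payloadmax))
			return -EMSGSIZE;

		/* Payload length, guaranteed to be <= payloadmax here. */
		len = cmsg->cmsg_len - CMSG_LEN(0);
		memcpy(payload, CMSG_DATA(cmsg), len);
		memset((uint8_t *)payload + len, 0, payloadmax - len);
		return (ssize_t)len;
	}
	return -ENOMSG;
}
```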
> +
> +static __always_inline
> +int smbdirect_buffer_remote_invalidate_cmsg_extract(const struct msghdr *msg,
> + u32 *remote_token)
> +{
> + struct smbdirect_buffer_remote_invalidate_args args = {
> + .remote_token = 0,
> + };
> + ssize_t ret;
> +
> + ret = __smbdirect_cmsg_extract(msg,
> + SMBDIRECT_BUFFER_REMOTE_INVALIDATE_CMSG_TYPE,
> + &args, sizeof(args), sizeof(args));
> + if (ret < 0)
> + return ret;
> +
> + *remote_token = args.remote_token;
> + return 0;
> +}
> +
> +static int smbdirect_sk_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t msg_len)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + struct iov_iter *iter = &msg->msg_iter;
> + unsigned int flags = msg->msg_flags;
> + size_t cmsg_count = 0;
> + int cmsg_type = -1;
> + bool need_invalidate = false;
> + u32 remote_key = 0;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + cmsg_count = smbdirect_cmsg_count(msg, &cmsg_type);
> + if (cmsg_count > 1) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (flags & ~(MSG_DONTWAIT|MSG_WAITALL|MSG_NOSIGNAL)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (cmsg_type == SMBDIRECT_BUFFER_REMOTE_INVALIDATE_CMSG_TYPE) {
> + ret = smbdirect_buffer_remote_invalidate_cmsg_extract(msg, &remote_key);
> + if (!ret)
> + need_invalidate = true; /* remote_key is valid */
> + else if (ret != -ENOMSG) {
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> + } else if (cmsg_count) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (WARN_ON_ONCE(iov_iter_rw(iter) != ITER_SOURCE)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (WARN_ON_ONCE(iov_iter_count(iter) != msg_len)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (flags & MSG_DONTWAIT) {
> + if (!sc->first_error && msg_len && atomic_read(&sc->send_io.credits.count) == 0) {
> + ret = -EAGAIN;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> + }
> + flags &= ~(MSG_DONTWAIT|MSG_WAITALL|MSG_NOSIGNAL);
> +
> + ret = smbdirect_connection_send_iter(sc,
> + iter,
> + flags,
> + need_invalidate,
> + remote_key);
> + if (ret < 0)
> + /* Handle error and possibly send SIGPIPE. */
> + ret = sk_stream_error(sk, msg->msg_flags, ret);
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len)
> +{
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + lock_sock(sk);
> + ret = smbdirect_sk_sendmsg_locked(sk, msg, msg_len);
> + release_sock(sk);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sk_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + struct iov_iter *iter = &msg->msg_iter;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + if (flags & ~(MSG_DONTWAIT|MSG_WAITALL|MSG_NOSIGNAL)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + if (WARN_ON_ONCE(iov_iter_rw(iter) != ITER_DEST)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + /*
> + * For now smbdirect_connection_recvmsg() relies
> + * on this assertion and the current in kernel
> + * users are working that way.
> + */
> + if (WARN_ON_ONCE(iov_iter_count(iter) != len)) {
> + ret = -EINVAL;
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> +
> + lock_sock(sk);
> + if (flags & MSG_DONTWAIT) {
> + if (!sc->first_error && len && sc->recv_io.reassembly.data_length == 0) {
> + ret = -EAGAIN;
> + release_sock(sk);
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> + }
> + }
> + flags &= ~(MSG_DONTWAIT|MSG_WAITALL|MSG_NOSIGNAL);
> + ret = smbdirect_connection_recvmsg(sc, msg, flags);
> + if (msg->msg_get_inq && ret >= 0)
> + msg->msg_inq = sc->recv_io.reassembly.data_length;
> + release_sock(sk);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static void smbdirect_sk_shutdown(struct sock *sk, int how)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + smbdirect_socket_schedule_cleanup(sc, -ESHUTDOWN);
> +
> + SMBDIRECT_FN_RETURN_VOID(sk);
> +}
> +
> +static int smbdirect_sk_disconnect(struct sock *sk, int flags)
> +{
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + smbdirect_socket_schedule_cleanup(sc, -ESHUTDOWN);
> +
> + if (flags & O_NONBLOCK) {
> + if (sc->status >= SMBDIRECT_SOCKET_DISCONNECTED) {
> + SMBDIRECT_FN_RETURN_INT(sk, 0);
> + return 0;
> + }
> +
> + /*
> + * This will cause SS_DISCONNECTING in
> + * smbdirect_sock_connect_locked().
> + */
> + SMBDIRECT_FN_RETURN_INT(sk, sc->first_error);
> + return sc->first_error;
> + }
> +
> + smbdirect_socket_destroy_sync(sc);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, 0);
> + return 0;
> +}
> +
> +static void smbdirect_sk_close(struct sock *sk, long timeout)
> +{
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + /*
> + * We hold an additional reference so
> + * that the sock_put() in sk_common_release()
> + * doesn't call sk_free(), that is potentially
> + * deferred to our sock_put() after release_sock().
> + *
> + * Note that sk_common_release() calls
> + * smbdirect_sk_destroy() as the first thing.
> + */
> + sock_hold(sk);
> + lock_sock(sk);
> + sk_common_release(sk);
> + release_sock(sk);
> + SMBDIRECT_FN_COMMENT(sk, "before sock_put()");
> + sock_put(sk);
> +}
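The sock_hold()/sock_put() bracket in smbdirect_sk_close() can be modelled with a plain reference count. The extra reference keeps the object alive across the common-release step, so the final free happens only after the lock has been dropped. This is a userspace sketch with hypothetical names. The flags only record whether the "free" ran while the lock was held.

```c
#include <assert.h>
#include <stdbool.h>

/* Userspace model; all names are hypothetical. */
struct demo_obj {
	int refcnt;
	bool locked;
	bool freed;
	bool freed_while_locked;
};

static void demo_put(struct demo_obj *o)
{
	if (--o->refcnt == 0) {
		o->freed_while_locked = o->locked;
		o->freed = true;	/* stands in for sk_free() */
	}
}

/* Drops the reference the socket layer held, like sk_common_release(). */
static void demo_common_release(struct demo_obj *o)
{
	demo_put(o);
}

static void demo_close(struct demo_obj *o, bool extra_hold)
{
	if (extra_hold)
		o->refcnt++;		/* sock_hold() */
	o->locked = true;		/* lock_sock() */
	demo_common_release(o);
	o->locked = false;		/* release_sock() */
	if (extra_hold)
		demo_put(o);		/* sock_put() */
}
```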
> +
> +static struct percpu_counter smbdirect_sockets_allocated;
> +
> +static struct proto smbdirect_prot = {
> + .name = "smbdirect",
> + .owner = THIS_MODULE,
> + .obj_size = sizeof(struct smbdirect_socket),
> + .ipv6_pinfo_offset = offsetof(struct smbdirect_socket, inet6),
> + .init = smbdirect_sk_init,
> + .destroy = smbdirect_sk_destroy,
> + .hash = smbdirect_sk_hash,
> + .unhash = smbdirect_sk_unhash,
> + .release_cb = smbdirect_sk_release_cb,
> + .bind = smbdirect_sk_bind,
> + .accept = smbdirect_sk_accept,
> + .pre_connect = smbdirect_sk_pre_connect,
> + .connect = smbdirect_sk_connect,
> + .setsockopt = smbdirect_sk_setsockopt,
> + .getsockopt = smbdirect_sk_getsockopt,
> + .ioctl = smbdirect_sk_ioctl,
> + .sendmsg = smbdirect_sk_sendmsg,
> + .recvmsg = smbdirect_sk_recvmsg,
> + .shutdown = smbdirect_sk_shutdown,
> + .disconnect = smbdirect_sk_disconnect,
> + .close = smbdirect_sk_close,
> + .sockets_allocated = &smbdirect_sockets_allocated,
> +};
> +
> +static int smbdirect_sock_release(struct socket *sock)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not locked */
> + sock_not_owned_by_me(sk);
> + WARN_ON_ONCE(sock_owned_by_user_nocheck(sk));
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + SMBDIRECT_FN_COMMENT(sk, "calling inet_release()");
> + ret = inet_release(sock);

Given setsockopt() is banned, smbdirect_sk_close() can be
inlined here.


> + break;
> + case AF_INET6:
> +#if IS_ENABLED(CONFIG_IPV6)
> + SMBDIRECT_FN_COMMENT(sk, "calling inet6_release()");
> + ret = inet6_release(sock);
> +#else
> + ret = -EAFNOSUPPORT;
> +#endif
> + break;
> + default:
> + ret = -EAFNOSUPPORT;
> + break;
> + }
> +
> + return ret;
> +}
> +
> +static int smbdirect_sock_bind(struct socket *sock, struct sockaddr_unsized *saddr, int len)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + ret = inet_bind(sock, saddr, len);

inet_bind() just calls sk->sk_prot->bind() if set.
So the same question applies: why not inline
sk->sk_prot->bind() here?


> + break;
> + case AF_INET6:
> +#if IS_ENABLED(CONFIG_IPV6)
> + ret = inet6_bind(sock, saddr, len);
> +#else
> + ret = -EAFNOSUPPORT;
> +#endif
> + break;
> + default:
> + ret = -EAFNOSUPPORT;
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_connect_locked(struct socket *sock,
> + struct sockaddr_unsized *uaddr,
> + int addr_len, int flags)
> +{
> + struct sock *sk = sock->sk;
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is already locked */
> + sock_owned_by_me(sk);
> +
> + if (addr_len < sizeof(uaddr->sa_family))
> + return -EINVAL;
> +
> + if (sk->sk_family != uaddr->sa_family)
> + return -EAFNOSUPPORT;
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + if (addr_len < sizeof(struct sockaddr_in))
> + return -EINVAL;
> + break;
> + case AF_INET6:
> + /*
> + * We require a full struct sockaddr_in6 (28 bytes) instead of a
> + * minimal size of SIN6_LEN_RFC2133 (24 bytes), as we don't
> + * have any legacy callers in userspace and the
> + * rdma layer also expects that.
> + */
> + if (addr_len < sizeof(struct sockaddr_in6))
> + return -EINVAL;
> + break;
> + default:
> + return -EAFNOSUPPORT;
> + }
> +
> + switch (sock->state) {
> + case SS_CONNECTED:
> + return -EISCONN;
> + case SS_CONNECTING:
> + return -EALREADY;
> + case SS_UNCONNECTED:
> + break;
> + default:
> + return -EINVAL;
> + }
> +
> + if (sc->status == SMBDIRECT_SOCKET_CONNECTED)
> + return -EISCONN;
> +
> + if (sc->status != SMBDIRECT_SOCKET_CREATED)
> + return -EINVAL;
> +
> + if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) {
> + ret = sk->sk_prot->pre_connect(sk, uaddr, addr_len);
> + if (ret)
> + return ret;
> + }
> +
> + ret = sk->sk_prot->connect(sk, uaddr, addr_len);
> + if (ret < 0)
> + return ret;
> +
> + inet_sk_set_state(sk, TCP_SYN_SENT);
> + sock->state = SS_CONNECTING;
> +
> + if (flags & O_NONBLOCK)
> + return -EINPROGRESS;
> +
> + ret = smbdirect_connection_wait_for_connected(sc);
> + if (ret)
> + goto sock_error;
> +
> + return 0;
> +
> +sock_error:
> + sock->state = SS_UNCONNECTED;
> + sk->sk_disconnects++;
> + if (sk->sk_prot->disconnect(sk, flags))
> + sock->state = SS_DISCONNECTING;
> + return ret;
> +}
> +
> +static int smbdirect_sock_connect(struct socket *sock,
> + struct sockaddr_unsized *uaddr,
> + int addr_len, int flags)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + lock_sock(sk);
> + ret = smbdirect_sock_connect_locked(sock, uaddr, addr_len, flags);
> + release_sock(sk);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_listen(struct socket *sock, int backlog)
> +{
> + struct sock *sk = sock->sk;
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + lock_sock(sk);
> + ret = smbdirect_socket_listen(sc, backlog);
> + release_sock(sk);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_accept(struct socket *lsock, struct socket *nsock,
> + struct proto_accept_arg *arg)
> +{
> + struct sock *lsk = lsock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(lsk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(lsk);
> +
> + ret = inet_accept(lsock, nsock, arg);

Could this account socket memory to memcg twice?
See 4a997d49d92a and 16942cf4d3e3.


> + if (!ret)
> + /*
> + * We want to handle all sockopts explicitly
> + * and only support what we really support.
> + */
> + set_bit(SOCK_CUSTOM_SOCKOPT, &nsock->flags);
> +
> + SMBDIRECT_FN_RETURN_INT(lsk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_getname(struct socket *sock, struct sockaddr *uaddr, int peer)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + switch (sk->sk_family) {
> + case AF_INET:
> + ret = inet_getname(sock, uaddr, peer);
> + break;
> + case AF_INET6:
> +#if IS_ENABLED(CONFIG_IPV6)
> + ret = inet6_getname(sock, uaddr, peer);
> +#else
> + ret = -EAFNOSUPPORT;
> +#endif
> + break;
> + default:
> + ret = -EAFNOSUPPORT;
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static __poll_t smbdirect_sock_poll(struct file *file, struct socket *sock, poll_table *wait)
> +{
> + struct sock *sk = sock->sk;
> + struct smbdirect_socket *sc = smbdirect_socket_from_sk(sk);
> + __poll_t mask = 0;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + sock_poll_wait(file, sock, wait);
> +
> + if (sc->status == SMBDIRECT_SOCKET_LISTENING) {
> + if (!list_empty_careful(&sc->listen.ready))
> + mask |= EPOLLIN | EPOLLRDNORM;
> + SMBDIRECT_FN_RETURN_POLL(sk, mask);
> + return mask;
> + }
> +
> + if (sc->first_error) {
> + /*
> + * A broken connection should report almost everything in order to let
> + * applications detect it reliably.
> + */
> + mask |= EPOLLHUP;
> + mask |= EPOLLERR;
> + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
> + mask |= EPOLLOUT | EPOLLWRNORM;
> + SMBDIRECT_FN_RETURN_POLL(sk, mask);
> + return mask;
> + }
> +
> + if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
> + /*
> + * A just created socket.
> + */
> + SMBDIRECT_FN_RETURN_POLL(sk, mask);
> + return mask;
> + }
> +
> + if (sc->recv_io.reassembly.data_length > 0)
> + mask |= EPOLLIN | EPOLLRDNORM;
> +
> + if (atomic_read(&sc->send_io.bcredits.count) > 0 &&
> + atomic_read(&sc->send_io.lcredits.count) > 0 &&
> + atomic_read(&sc->send_io.credits.count) > 0)
> + mask |= EPOLLOUT | EPOLLWRNORM;
> + else {
> + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
> + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
> +
> + /*
> + * Race breaker. If space is freed after
> + * wspace test but before the flags are set,
> + * IO signal will be lost. Memory barrier
> + * pairs with the input side.
> + */
> + smp_mb__after_atomic();
> + if (atomic_read(&sc->send_io.bcredits.count) > 0 &&
> + atomic_read(&sc->send_io.lcredits.count) > 0 &&
> + atomic_read(&sc->send_io.credits.count) > 0)
> + mask |= EPOLLOUT | EPOLLWRNORM;
> + }
> +
> + SMBDIRECT_FN_RETURN_POLL(sk, mask);
> + return mask;
> +}
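The "race breaker" above follows the usual pattern: test for space, set NOSPACE, issue a full barrier, then test again. The userspace model below uses C11 atomics and collapses the three credit counters into one. It assumes the credit-return side (not shown) adds credits, issues a matching full fence and only then checks the flag before waking pollers. All names are hypothetical.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Userspace model; all names are hypothetical. */
struct demo_conn {
	atomic_int credits;
	atomic_bool nospace;
};

static bool demo_writable(struct demo_conn *c)
{
	if (atomic_load_explicit(&c->credits, memory_order_relaxed) > 0)
		return true;

	atomic_store_explicit(&c->nospace, true, memory_order_relaxed);
	/* Pairs with the fence on the (assumed) credit-return side. */
	atomic_thread_fence(memory_order_seq_cst);

	/* Credits may have been returned after the first test. */
	return atomic_load_explicit(&c->credits, memory_order_relaxed) > 0;
}
```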
> +
> +static int smbdirect_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + /*
> + * We may need to handle some here as
> + * smbdirect_sk_ioctl() only gets a kernel
> + * int pointer as arg, but we may
> + * need the whole struct
> + */
> + switch (cmd) {
> + default:
> + /*
> + * Note this has some special handling for
> + * sk->sk_type == SOCK_RAW, in case we ever
> + * implement SOCK_RAW...
> + *
> + * It calls smbdirect_sk_ioctl()...
> + */
> + ret = sk_ioctl(sk, cmd, (void __user *)arg);
> + break;
> + }
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_shutdown(struct socket *sock, int how)
> +{
> + struct sock *sk = sock->sk;
> + int ret = 0;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + /*
> + * We have these from userspace:
> + * SHUT_RD = 0, SHUT_WR = 1 and SHUT_RDWR = 2
> + *
> + * And we map them to SHUTDOWN_MASK = 3
> + * RCV_SHUTDOWN = 1, SEND_SHUTDOWN = 2, BOTH = 3
> + */
> + how++;
> + if ((how & ~SHUTDOWN_MASK) || !how) /* MAXINT->0 */
> + return -EINVAL;
> +
> + lock_sock(sk);
> +
> + switch (sk->sk_state) {
> + case TCP_CLOSE:
> + ret = -ENOTCONN;
> + fallthrough;
> + default:
> + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | how);
> + sk->sk_prot->shutdown(sk, how);
> + break;
> +
> + case TCP_SYN_SENT:
> + case TCP_SYN_RECV:
> + ret = sk->sk_prot->disconnect(sk, O_NONBLOCK);
> + break;
> + }
> +
> + /* Wake up anyone sleeping in poll. */
> + sk->sk_state_change(sk);
> + release_sock(sk);
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
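The how++ step above maps SHUT_RD/SHUT_WR/SHUT_RDWR (0/1/2) onto RCV_SHUTDOWN/SEND_SHUTDOWN/both (1/2/3). The userspace version below performs the same check. It uses unsigned arithmetic so that -1 and INT_MAX are both rejected without relying on signed overflow. The RCV_SHUTDOWN/SEND_SHUTDOWN/SHUTDOWN_MASK values are copied from include/net/sock.h for the sketch, and the helper name is hypothetical.

```c
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <sys/socket.h>

/* Kernel-internal values, repeated here for a userspace sketch. */
#define RCV_SHUTDOWN	1
#define SEND_SHUTDOWN	2
#define SHUTDOWN_MASK	3

static int demo_shutdown_bits(int how)
{
	/* 0->1, 1->2, 2->3; -1 wraps to 0, INT_MAX to 0x80000000. */
	unsigned int bits = (unsigned int)how + 1u;

	if ((bits & ~(unsigned int)SHUTDOWN_MASK) || !bits)
		return -EINVAL;
	return (int)bits;
}
```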
> +
> +static int smbdirect_sock_setsockopt(struct socket *sock, int level, int optname,
> + sockptr_t optval, unsigned int optlen)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + ret = sock_common_setsockopt(sock, level, optname, optval, optlen);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_getsockopt(struct socket *sock, int level, int optname,
> + char __user *optval, int __user *optlen)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + ret = sock_common_getsockopt(sock, level, optname, optval, optlen);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + ret = sk->sk_prot->sendmsg(sk, msg, len);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static int smbdirect_sock_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
> + int flags)
> +{
> + struct sock *sk = sock->sk;
> + int ret;
> +
> + SMBDIRECT_FN_CALLED(sk);
> +
> + /* assert it is not already locked */
> + sock_not_owned_by_me(sk);
> +
> + ret = sock_common_recvmsg(sock, msg, size, flags);
> +
> + SMBDIRECT_FN_RETURN_INT(sk, ret);
> + return ret;
> +}
> +
> +static const struct proto_ops smbdirect_inet_proto_ops = {
> + .family = PF_INET,
> + .owner = THIS_MODULE,
> + .release = smbdirect_sock_release,
> + .bind = smbdirect_sock_bind,
> + .connect = smbdirect_sock_connect,
> + .socketpair = sock_no_socketpair,
> + .listen = smbdirect_sock_listen,
> + .accept = smbdirect_sock_accept,
> + .getname = smbdirect_sock_getname,
> + .poll = smbdirect_sock_poll,
> + .ioctl = smbdirect_sock_ioctl,
> + .shutdown = smbdirect_sock_shutdown,
> + .setsockopt = smbdirect_sock_setsockopt,
> + .getsockopt = smbdirect_sock_getsockopt,
> + .sendmsg = smbdirect_sock_sendmsg,
> + .sendmsg_locked = smbdirect_sk_sendmsg_locked,
> + .recvmsg = smbdirect_sock_recvmsg,
> + .mmap = sock_no_mmap,
> +};
> +
> +#if IS_ENABLED(CONFIG_IPV6)
> +static const struct proto_ops smbdirect_inet6_proto_ops = {
> + .family = PF_INET6,
> + .owner = THIS_MODULE,
> + .release = smbdirect_sock_release,
> + .bind = smbdirect_sock_bind,
> + .connect = smbdirect_sock_connect,
> + .socketpair = sock_no_socketpair,
> + .listen = smbdirect_sock_listen,
> + .accept = smbdirect_sock_accept,
> + .getname = smbdirect_sock_getname,
> + .poll = smbdirect_sock_poll,
> + .ioctl = smbdirect_sock_ioctl,
> + .shutdown = smbdirect_sock_shutdown,
> + .setsockopt = smbdirect_sock_setsockopt,
> + .getsockopt = smbdirect_sock_getsockopt,
> + .sendmsg = smbdirect_sock_sendmsg,
> + .sendmsg_locked = smbdirect_sk_sendmsg_locked,
> + .recvmsg = smbdirect_sock_recvmsg,
> + .mmap = sock_no_mmap,
> +};
> +#endif
> +
> +static struct inet_protosw smbdirect_inet_stream_protosw = {
> + .type = SOCK_STREAM,
> + .protocol = IPPROTO_SMBDIRECT,
> + .prot = &smbdirect_prot,
> + .ops = &smbdirect_inet_proto_ops,
> +};
> +
> +#if IS_ENABLED(CONFIG_IPV6)
> +static struct inet_protosw smbdirect_inet6_stream_protosw = {
> + .type = SOCK_STREAM,
> + .protocol = IPPROTO_SMBDIRECT,
> + .prot = &smbdirect_prot,
> + .ops = &smbdirect_inet6_proto_ops,
> +};
> +#endif
> +
> +struct smbdirect_socket *smbdirect_socket_from_sock(const struct socket *sock)
> +{
> + if (!sock ||
> + !sock->sk ||
> + sock->sk->sk_protocol != IPPROTO_SMBDIRECT)
> + return NULL;
> +
> + if (WARN_ON_ONCE(sock->sk->sk_destruct != smbdirect_sk_destruct))
> + return NULL;
> +
> + return smbdirect_socket_from_sk(sock->sk);
> +}
> +__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_from_sock);
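
smbdirect_socket_from_sock() checks sk_protocol and sk_destruct before it downcasts with container_of(). The userspace sketch below shows the same type-checked downcast. struct base_sock, struct derived_sock, derived_from_base() and PROTO_HYPOTHETICAL_SMBDIRECT are all hypothetical stand-ins for struct sock, struct smbdirect_socket and the helper above, not kernel API:

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-ins for struct sock / struct smbdirect_socket. */
#define PROTO_HYPOTHETICAL_SMBDIRECT 257

struct base_sock {
	int protocol;
	void (*destruct)(struct base_sock *sk);
};

struct derived_sock {
	struct base_sock sk;	/* must stay the first member */
	int status;
};

/* Mirrors the BUILD_BUG_ON() in smbdirect_socket_from_sk(). */
_Static_assert(offsetof(struct derived_sock, sk) == 0, "sk must be first");

/* Userspace equivalent of the kernel's container_of(). */
#define container_of(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

static void derived_destruct(struct base_sock *sk)
{
	(void)sk;
}

/* Downcast only after both the protocol and the destructor match. */
static struct derived_sock *derived_from_base(struct base_sock *sk)
{
	if (!sk || sk->protocol != PROTO_HYPOTHETICAL_SMBDIRECT)
		return NULL;
	if (sk->destruct != derived_destruct)
		return NULL;
	return container_of(sk, struct derived_sock, sk);
}
```

The destructor check matters because a matching protocol number on its own does not prove that the memory around the sock really is the larger structure.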
> +
> +static __init int smbdirect_protosw_init(void)
> +{
> + int err;
> +
> + err = proto_register(&smbdirect_prot, 1);
> + if (err)
> + return err;
> +
> + inet_register_protosw(&smbdirect_inet_stream_protosw);
> +#if IS_ENABLED(CONFIG_IPV6)
> + inet6_register_protosw(&smbdirect_inet6_stream_protosw);
> +#endif
> +
> + return 0;
> +}
> +
> +static __exit void smbdirect_protosw_exit(void)
> +{
> +#if IS_ENABLED(CONFIG_IPV6)
> + inet6_unregister_protosw(&smbdirect_inet6_stream_protosw);
> +#endif
> + inet_unregister_protosw(&smbdirect_inet_stream_protosw);
> +
> + proto_unregister(&smbdirect_prot);
> +}
> +
> +__init int smbdirect_proto_init(void)
> +{
> + int err;
> +
> + err = percpu_counter_init(&smbdirect_sockets_allocated, 0, GFP_KERNEL);
> + if (err)
> + goto err_percpu_counter;
> +
> + err = smbdirect_protosw_init();
> + if (err)
> + goto err_protosw;
> +
> + return 0;
> +
> +err_protosw:
> + percpu_counter_destroy(&smbdirect_sockets_allocated);
> +err_percpu_counter:
> + return err;
> +}
> +
> +__exit void smbdirect_proto_exit(void)
> +{
> + smbdirect_protosw_exit();
> + percpu_counter_destroy(&smbdirect_sockets_allocated);
> +}
> +
> +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET, 257 /* IPPROTO_SMBDIRECT */, SOCK_STREAM);
> +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET6, 257 /* IPPROTO_SMBDIRECT */, SOCK_STREAM);
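
smbdirect_proto_init() uses the usual goto-unwind pattern: each failure label undoes only the steps that already succeeded, in reverse order. The sketch below is a minimal userspace model of that pattern. The counter/protosw functions and the fail_protosw test hook are hypothetical placeholders for percpu_counter_init() and smbdirect_protosw_init():

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>

/* Hypothetical resources standing in for the percpu counter and protosw. */
static bool counter_ready, protosw_ready;
static bool fail_protosw;	/* test hook: force the second step to fail */

static int counter_init(void)
{
	counter_ready = true;
	return 0;
}

static void counter_destroy(void)
{
	counter_ready = false;
}

static int protosw_init(void)
{
	if (fail_protosw)
		return -ENODEV;
	protosw_ready = true;
	return 0;
}

static int proto_init(void)
{
	int err;

	err = counter_init();
	if (err)
		goto err_counter;

	err = protosw_init();
	if (err)
		goto err_protosw;

	return 0;

err_protosw:
	counter_destroy();	/* undo only what succeeded, in reverse order */
err_counter:
	return err;
}
```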
> diff --git a/fs/smb/common/smbdirect/smbdirect_public.h b/fs/smb/common/smbdirect/smbdirect_public.h
> index 50088155e7c3..9f96c66bbe32 100644
> --- a/fs/smb/common/smbdirect/smbdirect_public.h
> +++ b/fs/smb/common/smbdirect/smbdirect_public.h
> @@ -49,6 +49,7 @@ int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc,
> #define SMBDIRECT_LOG_RDMA_MR 0x100
> #define SMBDIRECT_LOG_RDMA_RW 0x200
> #define SMBDIRECT_LOG_NEGOTIATE 0x400
> +#define SMBDIRECT_LOG_SK 0x800
> void smbdirect_socket_set_logging(struct smbdirect_socket *sc,
> void *private_ptr,
> bool (*needed)(struct smbdirect_socket *sc,
> @@ -145,4 +146,6 @@ void smbdirect_connection_legacy_debug_proc_show(struct smbdirect_socket *sc,
> unsigned int rdma_readwrite_threshold,
> struct seq_file *m);
>
> +struct smbdirect_socket *smbdirect_socket_from_sock(const struct socket *sock);
> +
> #endif /* __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_PUBLIC_H__ */
> diff --git a/fs/smb/common/smbdirect/smbdirect_rw.c b/fs/smb/common/smbdirect/smbdirect_rw.c
> index 3b2eb8c48efc..154339955617 100644
> --- a/fs/smb/common/smbdirect/smbdirect_rw.c
> +++ b/fs/smb/common/smbdirect/smbdirect_rw.c
> @@ -105,11 +105,11 @@ static void smbdirect_connection_rdma_write_done(struct ib_cq *cq, struct ib_wc
> smbdirect_connection_rdma_rw_done(cq, wc, DMA_TO_DEVICE);
> }
>
> -int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
> - void *buf, size_t buf_len,
> - struct smbdirect_buffer_descriptor_v1 *desc,
> - size_t desc_len,
> - bool is_read)
> +static int smbdirect_connection_rdma_xmit_locked(struct smbdirect_socket *sc,
> + void *buf, size_t buf_len,
> + struct smbdirect_buffer_descriptor_v1 *desc,
> + size_t desc_len,
> + bool is_read)
> {
> const struct smbdirect_socket_parameters *sp = &sc->parameters;
> enum dma_data_direction direction = is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
> @@ -123,6 +123,8 @@ int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
> int credits_needed;
> size_t desc_buf_len, desc_num = 0;
>
> + smbdirect_socket_sk_owned_by_me(sc);
> +
> if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
> return -ENOTCONN;
>
> @@ -235,7 +237,9 @@ int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
> }
>
> msg = list_last_entry(&msg_list, struct smbdirect_rw_io, list);
> + smbdirect_socket_sk_unlock(sc);
> wait_for_completion(&completion);
> + smbdirect_socket_sk_lock(sc);
> ret = msg->error;
> out:
> list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
> @@ -252,4 +256,19 @@ int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
> kfree(msg);
> goto out;
> }
> +
> +int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
> + void *buf, size_t buf_len,
> + struct smbdirect_buffer_descriptor_v1 *desc,
> + size_t desc_len,
> + bool is_read)
> +{
> + int ret;
> +
> + smbdirect_socket_sk_lock(sc);
> + ret = smbdirect_connection_rdma_xmit_locked(sc, buf, buf_len, desc, desc_len, is_read);
> + smbdirect_socket_sk_unlock(sc);
> +
> + return ret;
> +}
> __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_rdma_xmit);
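
The smbdirect_rw.c change splits the function into a _locked variant, which asserts lock ownership, and a public wrapper that takes and releases the lock. The lock is dropped around wait_for_completion(). The single-threaded model below illustrates that contract. struct model_sock, model_lock(), model_unlock(), model_wait() and xmit() are hypothetical and stand in for the sk lock helpers, so they are not the smbdirect API:

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical single-threaded model of a socket lock with an owner check. */
struct model_sock {
	bool owned;
	int waits;	/* counts how often the lock was dropped to wait */
};

static void model_lock(struct model_sock *s)
{
	assert(!s->owned);
	s->owned = true;
}

static void model_unlock(struct model_sock *s)
{
	assert(s->owned);
	s->owned = false;
}

/* Stand-in for wait_for_completion(): must never run with the lock held. */
static void model_wait(struct model_sock *s)
{
	assert(!s->owned);
	s->waits++;
}

static int xmit_locked(struct model_sock *s)
{
	assert(s->owned);	/* like smbdirect_socket_sk_owned_by_me() */
	model_unlock(s);	/* drop the lock while blocking */
	model_wait(s);
	model_lock(s);		/* reacquire before touching state again */
	return 0;
}

/* Public entry point: take the lock, call the _locked variant, release. */
static int xmit(struct model_sock *s)
{
	int ret;

	model_lock(s);
	ret = xmit_locked(s);
	model_unlock(s);
	return ret;
}
```

Any state read before model_unlock() has to be treated as stale after model_lock() returns, because other callers may have run while the lock was dropped.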
> diff --git a/fs/smb/common/smbdirect/smbdirect_socket.c b/fs/smb/common/smbdirect/smbdirect_socket.c
> index 9153e1dbf53d..76e406999588 100644
> --- a/fs/smb/common/smbdirect/smbdirect_socket.c
> +++ b/fs/smb/common/smbdirect/smbdirect_socket.c
> @@ -5,6 +5,7 @@
> */
>
> #include "smbdirect_internal.h"
> +#include <net/transp_v6.h>
>
> bool smbdirect_frwr_is_supported(const struct ib_device_attr *attrs)
> {
> @@ -217,6 +218,7 @@ int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc,
> sc->send_io.mem.gfp_mask = gfp_mask;
> sc->recv_io.mem.gfp_mask = gfp_mask;
> sc->rw_io.mem.gfp_mask = gfp_mask;
> + sc->sk.sk_allocation = gfp_mask;
>
> return 0;
> }
> @@ -242,6 +244,106 @@ void smbdirect_socket_set_logging(struct smbdirect_socket *sc,
> }
> __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_logging);
>
> +int smbdirect_socket_sync_saddr_to_sk(struct smbdirect_socket *sc, bool *_is_any_addr)
> +{
> + struct sock *sk = &sc->sk;
> + const struct sockaddr_storage *saddr;
> + const struct sockaddr_in *sin;
> + const struct sockaddr_in6 *sin6;
> + struct in_addr sin_addr = { .s_addr = htonl(INADDR_ANY), };
> + struct in6_addr sin6_addr = in6addr_any;
> + __be32 sin6_flowinfo = 0;
> + bool is_any_addr = true;
> + u16 sport = 0;
> + int ret;
> +
> + saddr = &sc->rdma.cm_id->route.addr.src_addr;
> +
> + if (WARN_ON_ONCE(saddr->ss_family != sk->sk_family)) {
> + ret = -EINVAL;
> + return ret;
> + }
> +
> + switch (saddr->ss_family) {
> + case AF_INET:
> + sin = (struct sockaddr_in *)saddr;
> + sport = ntohs(sin->sin_port);
> + sin_addr = sin->sin_addr;
> + is_any_addr = (sin_addr.s_addr == htonl(INADDR_ANY));
> + break;
> +
> + case AF_INET6:
> + sin6 = (struct sockaddr_in6 *)saddr;
> + sport = ntohs(sin6->sin6_port);
> + sin_addr.s_addr = LOOPBACK4_IPV6;
> + sin6_addr = sin6->sin6_addr;
> + is_any_addr = ipv6_addr_any(&sin6_addr);
> + sin6_flowinfo = sin6->sin6_flowinfo;
> + break;
> + }
> +
> + sk->sk_bound_dev_if = sc->rdma.cm_id->route.addr.dev_addr.bound_dev_if;
> + sk->sk_rcv_saddr = sc->inet.inet_saddr = sin_addr.s_addr;
> +#if IS_ENABLED(CONFIG_IPV6)
> + sk->sk_v6_rcv_saddr = sc->inet6.saddr = sin6_addr;
> +#else
> + sc->inet6.saddr = sin6_addr;
> +#endif
> + sc->inet6.flow_label = sin6_flowinfo;
> + sk->sk_num = sport;
> + sc->inet.inet_sport = htons(sport);
> +
> + if (_is_any_addr)
> + *_is_any_addr = is_any_addr;
> + return 0;
> +}
> +
> +int smbdirect_socket_sync_daddr_to_sk(struct smbdirect_socket *sc)
> +{
> + struct sock *sk = &sc->sk;
> + const struct sockaddr_storage *daddr;
> + const struct sockaddr_in *sin;
> + const struct sockaddr_in6 *sin6;
> + struct in_addr sin_addr = { .s_addr = htonl(INADDR_ANY), };
> +#if IS_ENABLED(CONFIG_IPV6)
> + struct in6_addr sin6_addr = in6addr_any;
> +#endif
> + u16 dport = 0;
> + int ret;
> +
> + daddr = &sc->rdma.cm_id->route.addr.dst_addr;
> +
> + if (WARN_ON_ONCE(daddr->ss_family != sk->sk_family)) {
> + ret = -EINVAL;
> + return ret;
> + }
> +
> + switch (daddr->ss_family) {
> + case AF_INET:
> + sin = (struct sockaddr_in *)daddr;
> + dport = ntohs(sin->sin_port);
> + sin_addr = sin->sin_addr;
> + break;
> +
> + case AF_INET6:
> + sin6 = (struct sockaddr_in6 *)daddr;
> + dport = ntohs(sin6->sin6_port);
> + sin_addr.s_addr = LOOPBACK4_IPV6;
> +#if IS_ENABLED(CONFIG_IPV6)
> + sin6_addr = sin6->sin6_addr;
> +#endif
> + break;
> + }
> +
> + sk->sk_daddr = sc->inet.inet_daddr = sin_addr.s_addr;
> +#if IS_ENABLED(CONFIG_IPV6)
> + sk->sk_v6_daddr = sin6_addr;
> +#endif
> + sk->sk_dport = sc->inet.inet_dport = htons(dport);
> +
> + return 0;
> +}
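
The sync_saddr/sync_daddr helpers switch on ss_family, convert the port to host order with ntohs(), and work out whether the address is the wildcard. The userspace sketch below shows the same parsing using only standard socket headers. parse_saddr() is a hypothetical name, and its -1 return for unknown families is an assumption (the kernel code WARNs instead):

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

/*
 * Extract the host-order port and the "is wildcard" flag from a
 * sockaddr_storage. Returns -1 for an unexpected address family.
 */
static int parse_saddr(const struct sockaddr_storage *ss,
		       uint16_t *port, bool *is_any)
{
	const struct sockaddr_in *sin;
	const struct sockaddr_in6 *sin6;

	switch (ss->ss_family) {
	case AF_INET:
		sin = (const struct sockaddr_in *)ss;
		*port = ntohs(sin->sin_port);
		*is_any = sin->sin_addr.s_addr == htonl(INADDR_ANY);
		return 0;
	case AF_INET6:
		sin6 = (const struct sockaddr_in6 *)ss;
		*port = ntohs(sin6->sin6_port);
		*is_any = IN6_IS_ADDR_UNSPECIFIED(&sin6->sin6_addr);
		return 0;
	default:
		return -1;
	}
}
```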
> +
> static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc)
> {
> /*
> @@ -257,6 +359,38 @@ static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc)
> wake_up_all(&sc->recv_io.reassembly.wait_queue);
> wake_up_all(&sc->rw_io.credits.wait_queue);
> wake_up_all(&sc->mr_io.ready.wait_queue);
> +
> + if (sc->sk.sk_family) {
> + struct sock *sk = &sc->sk;
> +
> + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
> +
> + WARN_ON_ONCE(sc->first_error == 0);
> + if (sc->first_error < 0)
> + WRITE_ONCE(sk->sk_err, -sc->first_error);
> + else
> + WRITE_ONCE(sk->sk_err, sc->first_error);
> +
> + if (sc->status >= SMBDIRECT_SOCKET_DISCONNECTED) {
> + inet_sk_set_state(sk, TCP_CLOSE);
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_socket->state = SS_UNCONNECTED;
> + } else {
> + inet_sk_set_state(sk, TCP_CLOSING);
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_socket->state = SS_DISCONNECTING;
> + }
> +
> + /*
> + * Note tcp_done_with_error() also calls both
> + * sk->sk_state_change(sk) via tcp_done()
> + * and sk_error_report() directly.
> + */
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk->sk_state_change(sk);
> + if (!sock_flag(sk, SOCK_DEAD) && sk->sk_socket)
> + sk_error_report(sk);
> + }
> }
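
smbdirect_socket_wake_up_all() stores a positive errno in sk_err whatever the sign of first_error. It then picks TCP_CLOSE or TCP_CLOSING depending on whether the transport has already reached SMBDIRECT_SOCKET_DISCONNECTED. The two decisions are isolated below in a minimal sketch. The helper names and enum model_state are hypothetical placeholders for the inline logic and the TCP_* states:

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>

/* sk_err must hold a positive errno; accept first_error of either sign. */
static int sk_err_from_first_error(int first_error)
{
	return first_error < 0 ? -first_error : first_error;
}

/* Hypothetical state names standing in for TCP_CLOSE / TCP_CLOSING. */
enum model_state { MODEL_CLOSING, MODEL_CLOSE };

/* Fully disconnected transports are closed, others are still closing. */
static enum model_state state_after_error(bool fully_disconnected)
{
	return fully_disconnected ? MODEL_CLOSE : MODEL_CLOSING;
}
```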
>
> void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc,
> @@ -510,11 +644,13 @@ static void smbdirect_socket_destroy(struct smbdirect_socket *sc)
> */
> smbdirect_socket_wake_up_all(sc);
>
> + smbdirect_socket_sk_unlock(sc);
> disable_work_sync(&sc->disconnect_work);
> disable_work_sync(&sc->connect.work);
> disable_work_sync(&sc->recv_io.posted.refill_work);
> disable_work_sync(&sc->idle.immediate_work);
> disable_delayed_work_sync(&sc->idle.timer_work);
> + smbdirect_socket_sk_lock(sc);
>
> if (sc->rdma.cm_id)
> rdma_lock_handler(sc->rdma.cm_id);
> @@ -600,6 +736,8 @@ void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
> */
> WARN_ON_ONCE(in_interrupt());
>
> + smbdirect_socket_sk_owned_by_me(sc);
> +
> /*
> * First we try to disable the work
> * without disable_work_sync() in a
> @@ -625,7 +763,9 @@ void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
>
> smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
> "cancelling and disable disconnect_work\n");
> + smbdirect_socket_sk_unlock(sc);
> disable_work_sync(&sc->disconnect_work);
> + smbdirect_socket_sk_lock(sc);
>
> smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
> "destroying rdma session\n");
> @@ -634,7 +774,9 @@ void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
> if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) {
> smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
> "wait for transport being disconnected\n");
> + smbdirect_socket_sk_unlock(sc);
> wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
> + smbdirect_socket_sk_lock(sc);
> smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
> "waited for transport being disconnected\n");
> }
> @@ -723,6 +865,8 @@ int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc,
> {
> int ret;
>
> + smbdirect_socket_sk_owned_by_me(sc);
> +
> if (WARN_ON_ONCE(needed < 0))
> return -EINVAL;
>
> @@ -731,9 +875,12 @@ int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc,
> return 0;
>
> atomic_add(needed, total_credits);
> +
> + smbdirect_socket_sk_unlock(sc);
> ret = wait_event_interruptible(*waitq,
> atomic_read(total_credits) >= needed ||
> sc->status != expected_status);
> + smbdirect_socket_sk_lock(sc);
>
> if (sc->status != expected_status)
> return unexpected_errno;
> diff --git a/fs/smb/common/smbdirect/smbdirect_socket.h b/fs/smb/common/smbdirect/smbdirect_socket.h
> index c09eddd8ad16..6bb201683259 100644
> --- a/fs/smb/common/smbdirect/smbdirect_socket.h
> +++ b/fs/smb/common/smbdirect/smbdirect_socket.h
> @@ -104,6 +104,18 @@ enum smbdirect_keepalive_status {
> };
>
> struct smbdirect_socket {
> + union {
> + struct sock sk;
> + struct inet_sock inet;
> + };
> + /* needed by inet6_create() */
> + struct ipv6_pinfo inet6;
> + void (*orig_sk_destruct)(struct sock *sk);
> +
> + /*
> + * This is the first element that is
> + * initialized in smbdirect_socket_init()
> + */
> enum smbdirect_socket_status status;
> wait_queue_head_t status_wait;
> int first_error;
> @@ -548,14 +560,18 @@ static void __smbdirect_log_printf(struct smbdirect_socket *sc,
> __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_RW, fmt, ##args)
> #define smbdirect_log_negotiate(sc, lvl, fmt, args...) \
> __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_NEGOTIATE, fmt, ##args)
> +#define smbdirect_log_sk(sc, lvl, fmt, args...) \
> + __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_SK, fmt, ##args)
>
> static __always_inline void smbdirect_socket_init(struct smbdirect_socket *sc)
> {
> + const size_t status_offset = offsetof(struct smbdirect_socket, status);
> +
> /*
> * This also sets status = SMBDIRECT_SOCKET_CREATED
> */
> BUILD_BUG_ON(SMBDIRECT_SOCKET_CREATED != 0);
> - memset(sc, 0, sizeof(*sc));
> + memset((u8 *)sc + status_offset, 0, sizeof(*sc) - status_offset);
>
> init_waitqueue_head(&sc->status_wait);
>
> @@ -700,6 +716,14 @@ static __always_inline void smbdirect_socket_init(struct smbdirect_socket *sc)
> __SMBDIRECT_CHECK_STATUS_WARN(__sc, __expected_status, \
> __SMBDIRECT_SOCKET_DISCONNECT(__sc);)
>
> +static __always_inline struct smbdirect_socket *
> +smbdirect_socket_from_sk(const struct sock *sk)
> +{
> + WARN_ON_ONCE(!sk);
> + BUILD_BUG_ON(offsetof(struct smbdirect_socket, sk) != 0);
> + return container_of(sk, struct smbdirect_socket, sk);
> +}
> +
> struct smbdirect_send_io {
> struct smbdirect_socket *socket;
> struct ib_cqe cqe;




> --
> 2.43.0
>
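
With the struct sock embedded at offset 0, smbdirect_socket_init() no longer zeroes the whole structure. It clears only the range from the status member to the end, so the sock fields already set up by the socket core survive. The sketch below is a minimal userspace model of that partial memset(). struct model_socket and model_socket_init() are hypothetical, and its leading fields merely stand in for the embedded sock:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical layout: an embedded header that must survive re-init. */
struct model_socket {
	int sk_refcnt;		/* stands in for the embedded struct sock */
	long sk_flags;
	/* everything from here on is reset by model_socket_init() */
	int status;
	int first_error;
	char scratch[16];
};

static void model_socket_init(struct model_socket *sc)
{
	const size_t status_offset = offsetof(struct model_socket, status);

	/* Zero only [status, end) and leave the embedded header untouched. */
	memset((unsigned char *)sc + status_offset, 0,
	       sizeof(*sc) - status_offset);
}
```

This only works while every field that must be preserved sits before status. That is why the patch adds a comment marking status as the first field initialized by smbdirect_socket_init().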