[PATCH net-next v2 1/2] net: vxlan: enable local address bind for vxlan sockets

From: Richard Gobert
Date: Mon Jul 08 2024 - 07:13:04 EST


This patch adds support for binding vxlan sockets to a local address.
It achieves this by using the vxlan_addr union to represent the local
address to bind to, and copying it into udp_port_cfg in vxlan_create_sock().

Also change vxlan_find_sock() so that it also matches sockets on their bound
(listening) local address.

Signed-off-by: Richard Gobert <richardbgobert@xxxxxxxxx>
---
drivers/net/vxlan/vxlan_core.c | 53 ++++++++++++++++++++++++----------
1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/drivers/net/vxlan/vxlan_core.c b/drivers/net/vxlan/vxlan_core.c
index ba59e92ab941..9a797147beb7 100644
--- a/drivers/net/vxlan/vxlan_core.c
+++ b/drivers/net/vxlan/vxlan_core.c
@@ -72,22 +72,34 @@ static inline bool vxlan_collect_metadata(struct vxlan_sock *vs)
 }
 
 /* Find VXLAN socket based on network namespace, address family, UDP port,
- * enabled unshareable flags and socket device binding (see l3mdev with
- * non-default VRF).
+ * bound local address, enabled unshareable flags and socket device binding
+ * (see l3mdev with non-default VRF).
  */
 static struct vxlan_sock *vxlan_find_sock(struct net *net, sa_family_t family,
-					  __be16 port, u32 flags, int ifindex)
+					  __be16 port, u32 flags, int ifindex,
+					  const union vxlan_addr *saddr)
 {
 	struct vxlan_sock *vs;
 
 	flags &= VXLAN_F_RCV_FLAGS;
 
 	hlist_for_each_entry_rcu(vs, vs_head(net, port), hlist) {
-		if (inet_sk(vs->sock->sk)->inet_sport == port &&
+		struct sock *sk = vs->sock->sk;
+
+		if (inet_sk(sk)->inet_sport == port &&
 		    vxlan_get_sk_family(vs) == family &&
 		    vs->flags == flags &&
-		    vs->sock->sk->sk_bound_dev_if == ifindex)
-			return vs;
+		    sk->sk_bound_dev_if == ifindex) {
+			if (family == AF_INET &&
+			    sk->sk_rcv_saddr == saddr->sin.sin_addr.s_addr)
+				return vs;
+#if IS_ENABLED(CONFIG_IPV6)
+			if (family == AF_INET6 &&
+			    ipv6_addr_equal(&sk->sk_v6_rcv_saddr,
+					    &saddr->sin6.sin6_addr))
+				return vs;
+#endif
+		}
 	}
 	return NULL;
 }
@@ -135,11 +147,11 @@ static struct vxlan_dev *vxlan_vs_find_vni(struct vxlan_sock *vs,
/* Look up VNI in a per net namespace table */
static struct vxlan_dev *vxlan_find_vni(struct net *net, int ifindex,
__be32 vni, sa_family_t family,
- __be16 port, u32 flags)
+ __be16 port, u32 flags, union vxlan_addr *saddr)
{
struct vxlan_sock *vs;

- vs = vxlan_find_sock(net, family, port, flags, ifindex);
+ vs = vxlan_find_sock(net, family, port, flags, ifindex, saddr);
if (!vs)
return NULL;

@@ -2315,7 +2327,7 @@ static int encap_bypass_if_local(struct sk_buff *skb, struct net_device *dev,
dst_release(dst);
dst_vxlan = vxlan_find_vni(vxlan->net, dst_ifindex, vni,
addr_family, dst_port,
- vxlan->cfg.flags);
+ vxlan->cfg.flags, &vxlan->cfg.saddr);
if (!dst_vxlan) {
DEV_STATS_INC(dev, tx_errors);
vxlan_vnifilter_count(vxlan, vni, NULL,
@@ -3503,8 +3515,9 @@ static const struct ethtool_ops vxlan_ethtool_ops = {
.get_link_ksettings = vxlan_get_link_ksettings,
};

-static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
- __be16 port, u32 flags, int ifindex)
+static struct socket *vxlan_create_sock(struct net *net, bool ipv6, __be16 port,
+ u32 flags, int ifindex,
+ union vxlan_addr *addr)
{
struct socket *sock;
struct udp_port_cfg udp_conf;
@@ -3517,8 +3530,14 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
 		udp_conf.use_udp6_rx_checksums =
 		    !(flags & VXLAN_F_UDP_ZERO_CSUM6_RX);
 		udp_conf.ipv6_v6only = 1;
+#if IS_ENABLED(CONFIG_IPV6)
+		/* Bind to the configured local address; :: binds to any. */
+		udp_conf.local_ip6 = addr->sin6.sin6_addr;
+#endif
 	} else {
 		udp_conf.family = AF_INET;
+		/* Bind to the configured local address; 0.0.0.0 binds to any. */
+		udp_conf.local_ip.s_addr = addr->sin.sin_addr.s_addr;
 	}
 
 	udp_conf.local_udp_port = port;
@@ -3536,7 +3558,8 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
/* Create new listen socket if needed */
static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
__be16 port, u32 flags,
- int ifindex)
+ int ifindex,
+ union vxlan_addr *addr)
{
struct vxlan_net *vn = net_generic(net, vxlan_net_id);
struct vxlan_sock *vs;
@@ -3551,7 +3574,7 @@ static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
for (h = 0; h < VNI_HASH_SIZE; ++h)
INIT_HLIST_HEAD(&vs->vni_list[h]);

- sock = vxlan_create_sock(net, ipv6, port, flags, ifindex);
+ sock = vxlan_create_sock(net, ipv6, port, flags, ifindex, addr);
if (IS_ERR(sock)) {
kfree(vs);
return ERR_CAST(sock);
@@ -3605,7 +3628,7 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
spin_lock(&vn->sock_lock);
vs = vxlan_find_sock(vxlan->net, ipv6 ? AF_INET6 : AF_INET,
vxlan->cfg.dst_port, vxlan->cfg.flags,
- l3mdev_index);
+ l3mdev_index, &vxlan->cfg.saddr);
if (vs && !refcount_inc_not_zero(&vs->refcnt)) {
spin_unlock(&vn->sock_lock);
return -EBUSY;
@@ -3615,7 +3638,7 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
if (!vs)
vs = vxlan_socket_create(vxlan->net, ipv6,
vxlan->cfg.dst_port, vxlan->cfg.flags,
- l3mdev_index);
+ l3mdev_index, &vxlan->cfg.saddr);
if (IS_ERR(vs))
return PTR_ERR(vs);
#if IS_ENABLED(CONFIG_IPV6)
--
2.36.1