Re: [PATCH v1 3/3] net: simplify sk_page_frag

From: Paolo Abeni
Date: Fri Dec 09 2022 - 11:56:00 EST


On Mon, 2022-11-21 at 08:35 -0500, Benjamin Coddington wrote:
> Now that in-kernel socket users that may recurse during reclaim have been
> converted to sk_use_task_frag = false, we can have sk_page_frag() simply
> check that value.
>
> Signed-off-by: Benjamin Coddington <bcodding@xxxxxxxxxx>
> ---
> include/net/sock.h | 9 ++-------
> 1 file changed, 2 insertions(+), 7 deletions(-)
>
> diff --git a/include/net/sock.h b/include/net/sock.h
> index ffba9e95470d..fac24c6ee30d 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -2539,19 +2539,14 @@ static inline void sk_stream_moderate_sndbuf(struct sock *sk)
> * Both direct reclaim and page faults can nest inside other
> * socket operations and end up recursing into sk_page_frag()
> * while it's already in use: explicitly avoid task page_frag
> - * usage if the caller is potentially doing any of them.
> - * This assumes that page fault handlers use the GFP_NOFS flags or
> - * explicitly disable sk_use_task_frag.
> + * when users disable sk_use_task_frag.
> *
> * Return: a per task page_frag if context allows that,
> * otherwise a per socket one.
> */
> static inline struct page_frag *sk_page_frag(struct sock *sk)
> {
> - if (sk->sk_use_task_frag &&
> - (sk->sk_allocation & (__GFP_DIRECT_RECLAIM | __GFP_MEMALLOC |
> - __GFP_FS)) ==
> - (__GFP_DIRECT_RECLAIM | __GFP_FS))
> + if (sk->sk_use_task_frag)
> return &current->task_frag;
>
> return &sk->sk_frag;

To make the above as safe as possible I think we should double-check
the in-kernel users explicitly setting sk_allocation to GFP_ATOMIC, as
that has the side effect of disabling the task_frag usage, too.
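
To illustrate the concern, here is a minimal userspace sketch (not
kernel code) comparing the check removed above with the new one. The
DEMO_* bit values, struct demo_sock and the helper names are made up
for the example; only the way the flags are combined is meant to
follow the kernel's GFP_ATOMIC/GFP_NOFS/GFP_KERNEL definitions, and
that should be double-checked against the tree being patched.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical bit values, for illustration only; the real flags are
 * defined by the kernel and their values differ. */
#define DEMO_IO             0x01u
#define DEMO_FS             0x02u
#define DEMO_DIRECT_RECLAIM 0x04u
#define DEMO_KSWAPD_RECLAIM 0x08u
#define DEMO_MEMALLOC       0x10u
#define DEMO_HIGH           0x20u

/* Assumed compositions: the atomic mask never allows direct reclaim,
 * the NOFS mask allows reclaim but not FS recursion. */
#define DEMO_GFP_ATOMIC (DEMO_HIGH | DEMO_KSWAPD_RECLAIM)
#define DEMO_GFP_NOFS   (DEMO_DIRECT_RECLAIM | DEMO_KSWAPD_RECLAIM | DEMO_IO)
#define DEMO_GFP_KERNEL (DEMO_GFP_NOFS | DEMO_FS)

/* Stand-in for the two struct sock fields the check looks at. */
struct demo_sock {
	unsigned int sk_allocation;
	bool sk_use_task_frag;
};

/* Old logic: task frag only if direct reclaim and FS are both allowed
 * and MEMALLOC is not set. */
static bool old_uses_task_frag(const struct demo_sock *sk)
{
	const unsigned int mask = DEMO_DIRECT_RECLAIM | DEMO_MEMALLOC | DEMO_FS;

	return sk->sk_use_task_frag &&
	       (sk->sk_allocation & mask) == (DEMO_DIRECT_RECLAIM | DEMO_FS);
}

/* New logic: the flag alone decides. */
static bool new_uses_task_frag(const struct demo_sock *sk)
{
	return sk->sk_use_task_frag;
}

/* A socket that only set sk_allocation to the atomic mask changes
 * behaviour between the two checks. */
static void demo_check(void)
{
	struct demo_sock sk = { DEMO_GFP_ATOMIC, true };

	assert(!old_uses_task_frag(&sk));
	assert(new_uses_task_frag(&sk));
}
```

With the old check such a socket falls back to sk->sk_frag; with the
new one it uses current->task_frag unless sk_use_task_frag is cleared
explicitly.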

Patch 2/3 already catches some of those users, and we can safely leave
a few others alone (specifically l2tp, fou and inet_ctl_sock_create()).

Even wireguard and tls look safe IMHO.

So the only left-over should be espintcp: I suggest updating patch 2/3
to clear sk_use_task_frag in espintcp_init_sk(), too.

Other than that LGTM.

Cheers,

Paolo