Re: [RFC PATCH v1 07/12] vsock/virtio: MGS_ZEROCOPY flag support

From: Krasnov Arseniy
Date: Mon Feb 20 2023 - 04:04:17 EST


On 16.02.2023 18:16, Stefano Garzarella wrote:
> On Mon, Feb 06, 2023 at 07:00:35AM +0000, Arseniy Krasnov wrote:
>> This adds main logic of MSG_ZEROCOPY flag processing for packet
>> creation. When this flag is set and user's iov iterator fits for
>> zerocopy transmission, call 'get_user_pages()' and add returned
>> pages to the newly created skb.
>>
>> Signed-off-by: Arseniy Krasnov <AVKrasnov@xxxxxxxxxxxxxx>
>> ---
>> net/vmw_vsock/virtio_transport_common.c | 212 ++++++++++++++++++++++--
>> 1 file changed, 195 insertions(+), 17 deletions(-)
>>
>> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
>> index 05ce97b967ad..69e37f8a68a6 100644
>> --- a/net/vmw_vsock/virtio_transport_common.c
>> +++ b/net/vmw_vsock/virtio_transport_common.c
>> @@ -37,6 +37,169 @@ virtio_transport_get_ops(struct vsock_sock *vsk)
>>     return container_of(t, struct virtio_transport, transport);
>> }
>>
>
> I'd use bool if we don't need to return an error value in the following
> new functions.
>
>> +static int virtio_transport_can_zcopy(struct iov_iter *iov_iter,
>> +                      size_t free_space)
>> +{
>> +    size_t pages;
>> +    int i;
>> +
>> +    if (!iter_is_iovec(iov_iter))
>> +        return -1;
>> +
>> +    if (iov_iter->iov_offset)
>> +        return -1;
>> +
>> +    /* We can't send whole iov. */
>> +    if (free_space < iov_iter->count)
>> +        return -1;
>> +
>> +    for (pages = 0, i = 0; i < iov_iter->nr_segs; i++) {
>> +        const struct iovec *iovec;
>> +        int pages_in_elem;
>> +
>> +        iovec = &iov_iter->iov[i];
>> +
>> +        /* Base must be page aligned. */
>> +        if (offset_in_page(iovec->iov_base))
>> +            return -1;
>> +
>> +        /* Only last element could have not page aligned size.  */
>> +        if (i != (iov_iter->nr_segs - 1)) {
>> +            if (offset_in_page(iovec->iov_len))
>> +                return -1;
>> +
>> +            pages_in_elem = iovec->iov_len >> PAGE_SHIFT;
>> +        } else {
>> +            pages_in_elem = round_up(iovec->iov_len, PAGE_SIZE);
>> +            pages_in_elem >>= PAGE_SHIFT;
>> +        }
>> +
>> +        /* In case of user's pages - one page is one frag. */
>> +        if (pages + pages_in_elem > MAX_SKB_FRAGS)
>> +            return -1;
>> +
>> +        pages += pages_in_elem;
>> +    }
>> +
>> +    return 0;
>> +}
>> +
>> +static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk,
>> +                       struct sk_buff *skb,
>> +                       struct iov_iter *iter,
>> +                       bool zerocopy)
>> +{
>> +    struct ubuf_info_msgzc *uarg_zc;
>> +    struct ubuf_info *uarg;
>> +
>> +    uarg = msg_zerocopy_realloc(sk_vsock(vsk),
>> +                    iov_length(iter->iov, iter->nr_segs),
>> +                    NULL);
>> +
>> +    if (!uarg)
>> +        return -1;
>> +
>> +    uarg_zc = uarg_to_msgzc(uarg);
>> +    uarg_zc->zerocopy = zerocopy ? 1 : 0;
>> +
>> +    skb_zcopy_init(skb, uarg);
>> +
>> +    return 0;
>> +}
>> +
>> +static int virtio_transport_fill_nonlinear_skb(struct sk_buff *skb,
>> +                           struct vsock_sock *vsk,
>> +                           struct virtio_vsock_pkt_info *info)
>> +{
>> +    struct iov_iter *iter;
>> +    int frag_idx;
>> +    int seg_idx;
>> +
>> +    iter = &info->msg->msg_iter;
>> +    frag_idx = 0;
>> +    VIRTIO_VSOCK_SKB_CB(skb)->curr_frag = 0;
>> +    VIRTIO_VSOCK_SKB_CB(skb)->frag_off = 0;
>> +
>> +    /* At this moment:
>> +     * 1) 'iov_offset' is zero.
>> +     * 2) Every 'iov_base' and 'iov_len' are also page aligned
>> +     *    (except length of the last element).
>> +     * 3) Number of pages in this iov <= MAX_SKB_FRAGS.
>> +     * 4) Length of the data fits in current credit space.
>> +     */
>> +    for (seg_idx = 0; seg_idx < iter->nr_segs; seg_idx++) {
>> +        struct page *user_pages[MAX_SKB_FRAGS];
>> +        const struct iovec *iovec;
>> +        size_t last_frag_len;
>> +        size_t pages_in_seg;
>> +        int page_idx;
>> +
>> +        iovec = &iter->iov[seg_idx];
>> +        pages_in_seg = iovec->iov_len >> PAGE_SHIFT;
>> +
>> +        if (iovec->iov_len % PAGE_SIZE) {
>> +            last_frag_len = iovec->iov_len % PAGE_SIZE;
>> +            pages_in_seg++;
>> +        } else {
>> +            last_frag_len = PAGE_SIZE;
>> +        }
>> +
>> +        if (get_user_pages((unsigned long)iovec->iov_base,
>> +                   pages_in_seg, FOLL_GET, user_pages,
>> +                   NULL) != pages_in_seg)
>> +            return -1;
>
> Reading the get_user_pages() documentation, this should pin the user
> pages, so we should be fine if we then expose them in the virtqueue.
>
> But reading Documentation/core-api/pin_user_pages.rst it seems that
> drivers should use "pin_user_pages*() for DMA-pinned pages", so I'm not
> sure what we should do.
>
That is a really interesting question for me too. IIUC 'pin_user_pages()'
sets a special value in the page's ref counter, so such pages can be
distinguished from the others. I've grepped for checks of pinned pages and
found that this is used in mm/vmscan.c, which calls 'folio_maybe_dma_pinned()'
during page list processing. It seems 'pin_user_pages()' is a stricter version
of 'get_user_pages()', and it is recommended to use 'pin_' when the data on
these pages will be accessed.
I think I'll check which API is used in the TCP implementation for zerocopy
transmission.

> Additional advice would be great!
>
> Anyway, when we are done using the pages, we should call put_page() or
> unpin_user_page() depending on how we pin them.
>
In case of 'get_user_pages()' everything is ok here: when such an skb
is released, 'put_page()' will be called for every frag page
of it, so there is no page leak. But in case of 'pin_user_pages()',
I will need to unpin them manually before calling 'consume_skb()'
after the skb is processed by the virtio device. But anyway - it is not a
problem.
>> +
>> +        for (page_idx = 0; page_idx < pages_in_seg; page_idx++) {
>> +            int frag_len = PAGE_SIZE;
>> +
>> +            if (page_idx == (pages_in_seg - 1))
>> +                frag_len = last_frag_len;
>> +
>> +            skb_fill_page_desc(skb, frag_idx++,
>> +                       user_pages[page_idx], 0,
>> +                       frag_len);
>> +            skb_len_add(skb, frag_len);
>> +        }
>> +    }
>> +
>> +    return virtio_transport_init_zcopy_skb(vsk, skb, iter, true);
>> +}
>> +
>> +static int virtio_transport_copy_payload(struct sk_buff *skb,
>> +                     struct vsock_sock *vsk,
>> +                     struct virtio_vsock_pkt_info *info,
>> +                     size_t len)
>> +{
>> +    void *payload;
>> +    int err;
>> +
>> +    payload = skb_put(skb, len);
>> +    err = memcpy_from_msg(payload, info->msg, len);
>> +    if (err)
>> +        return -1;
>> +
>> +    if (msg_data_left(info->msg))
>> +        return 0;
>> +
>> +    if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
>> +        struct virtio_vsock_hdr *hdr;
>> +
>> +        hdr = virtio_vsock_hdr(skb);
>> +
>> +        hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
>> +
>> +        if (info->msg->msg_flags & MSG_EOR)
>> +            hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
>> +    }
>> +
>
> A comment here explaining why this is necessary would be helpful.
>
>> +    if (info->flags & MSG_ZEROCOPY)
>> +        return virtio_transport_init_zcopy_skb(vsk, skb,
>> +                               &info->msg->msg_iter,
>> +                               false);
>> +
>> +    return 0;
>> +}
>> +
>> /* Returns a new packet on success, otherwise returns NULL.
>>  *
>>  * If NULL is returned, errp is set to a negative errno.
>> @@ -47,15 +210,31 @@ virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
>>                u32 src_cid,
>>                u32 src_port,
>>                u32 dst_cid,
>> -               u32 dst_port)
>> +               u32 dst_port,
>> +               struct vsock_sock *vsk)
>> {
>> -    const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
>> +    const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM;
>>     struct virtio_vsock_hdr *hdr;
>>     struct sk_buff *skb;
>> -    void *payload;
>> -    int err;
>> +    bool use_zcopy = false;
>> +
>> +    if (info->msg) {
>> +        /* If SOCK_ZEROCOPY is not enabled, ignore MSG_ZEROCOPY
>> +         * flag later and continue in classic way(e.g. without
>> +         * completion).
>> +         */
>> +        if (!sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY)) {
>
> `vsk` can be null, should we check it?
> Otherwise, what about passing only a flag?
> So the caller will check it.
>
>> +            info->flags &= ~MSG_ZEROCOPY;
>> +        } else {
>> +            if ((info->flags & MSG_ZEROCOPY) &&
>> +                !virtio_transport_can_zcopy(&info->msg->msg_iter, len)) {
>
> This part is not very clear, I think virtio_transport_can_zcopy()
> should return `true` if "can_zcopy".
>
>> +                use_zcopy = true;
>> +            }
>> +        }
>> +    }
>>
>> -    skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
>> +    /* For MSG_ZEROCOPY length will be added later. */
>> +    skb = virtio_vsock_alloc_skb(skb_len + (use_zcopy ? 0 : len), GFP_KERNEL);
>
> I think it is better to adjust `skb_len` in the previous block: when we set
> `use_zcopy = true`, we can do `skb_len -= len` (with the comment).
>
>>     if (!skb)
>>         return NULL;
>>
>> @@ -70,18 +249,15 @@ virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
>>     hdr->len    = cpu_to_le32(len);
>>
>>     if (info->msg && len > 0) {
>> -        payload = skb_put(skb, len);
>> -        err = memcpy_from_msg(payload, info->msg, len);
>> -        if (err)
>> -            goto out;
>> +        int err;
>>
>> -        if (msg_data_left(info->msg) == 0 &&
>> -            info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
>> -            hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
>> +        if (use_zcopy)
>> +            err = virtio_transport_fill_nonlinear_skb(skb, vsk, info);
>> +        else
>> +            err = virtio_transport_copy_payload(skb, vsk, info, len);
>>
>> -            if (info->msg->msg_flags & MSG_EOR)
>> -                hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
>> -        }
>> +        if (err)
>> +            goto out;
>>     }
>>
>>     if (info->reply)
>> @@ -266,7 +442,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
>>
>>     skb = virtio_transport_alloc_skb(info, pkt_len,
>>                      src_cid, src_port,
>> -                     dst_cid, dst_port);
>> +                     dst_cid, dst_port,
>> +                     vsk);
>>     if (!skb) {
>>         virtio_transport_put_credit(vvs, pkt_len);
>>         return -ENOMEM;
>> @@ -842,6 +1019,7 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
>>         .msg = msg,
>>         .pkt_len = len,
>>         .vsk = vsk,
>> +        .flags = msg->msg_flags,
>>     };
>>
>>     return virtio_transport_send_pkt_info(vsk, &info);
>> @@ -894,7 +1072,7 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
>>                        le64_to_cpu(hdr->dst_cid),
>>                        le32_to_cpu(hdr->dst_port),
>>                        le64_to_cpu(hdr->src_cid),
>> -                       le32_to_cpu(hdr->src_port));
>> +                       le32_to_cpu(hdr->src_port), NULL);
>>     if (!reply)
>>         return -ENOMEM;
>>
>> -- 
>> 2.25.1
>