On 06.11.2022 22:41, Arseniy Krasnov wrote:
> This adds a transport callback which processes the rx queue of a socket
> and, instead of copying data to a user provided buffer, inserts the data
> pages of each packet into the user's vm area.
>
> Signed-off-by: Arseniy Krasnov <AVKrasnov@xxxxxxxxxxxxxx>
> ---
>  include/linux/virtio_vsock.h            |   7 +
>  include/uapi/linux/virtio_vsock.h       |  14 ++
>  net/vmw_vsock/virtio_transport_common.c | 244 +++++++++++++++++++++++-
>  3 files changed, 261 insertions(+), 4 deletions(-)
>
> diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> index c1be40f89a89..d10fdfd8d144 100644
> --- a/include/linux/virtio_vsock.h
> +++ b/include/linux/virtio_vsock.h
> @@ -37,6 +37,7 @@ struct virtio_vsock_sock {
>  	u32 buf_alloc;
>  	struct list_head rx_queue;
>  	u32 msg_count;
> +	struct page *usr_poll_page;
>  };
>
>  struct virtio_vsock_pkt {
> @@ -51,6 +52,7 @@ struct virtio_vsock_pkt {
>  	bool reply;
>  	bool tap_delivered;
>  	bool slab_buf;
> +	bool split;
>  };
>
>  struct virtio_vsock_pkt_info {
> @@ -131,6 +133,11 @@ int virtio_transport_dgram_bind(struct vsock_sock *vsk,
>  				struct sockaddr_vm *addr);
>  bool virtio_transport_dgram_allow(u32 cid, u32 port);
>
> +int virtio_transport_zerocopy_init(struct vsock_sock *vsk,
> +				   struct vm_area_struct *vma);
> +int virtio_transport_zerocopy_dequeue(struct vsock_sock *vsk,
> +				      struct page **pages,
> +				      unsigned long *pages_num);
>  int virtio_transport_connect(struct vsock_sock *vsk);
>
>  int virtio_transport_shutdown(struct vsock_sock *vsk, int mode);
> diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h
> index 64738838bee5..2a0e4f309918 100644
> --- a/include/uapi/linux/virtio_vsock.h
> +++ b/include/uapi/linux/virtio_vsock.h
> @@ -66,6 +66,20 @@ struct virtio_vsock_hdr {
>  	__le32 fwd_cnt;
>  } __attribute__((packed));
>
> +struct virtio_vsock_usr_hdr {
> +	u32 flags;
> +	u32 len;
> +} __attribute__((packed));
> +
> +#define VIRTIO_VSOCK_USR_POLL_NO_DATA 0
> +#define VIRTIO_VSOCK_USR_POLL_HAS_DATA 1
> +#define VIRTIO_VSOCK_USR_POLL_SHUTDOWN ~0
> +
> +struct virtio_vsock_usr_hdr_pref {
> +	u32 poll_value;
> +	u32 hdr_num;
> +} __attribute__((packed));
> +
>  enum virtio_vsock_type {
>  	VIRTIO_VSOCK_TYPE_STREAM = 1,
>  	VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> index 444764869670..fa4a2688a5d5 100644
> --- a/net/vmw_vsock/virtio_transport_common.c
> +++ b/net/vmw_vsock/virtio_transport_common.c
> @@ -12,6 +12,7 @@
>  #include <linux/ctype.h>
>  #include <linux/list.h>
>  #include <linux/virtio_vsock.h>
> +#include <linux/mm.h>
>  #include <uapi/linux/vsockmon.h>
>
>  #include <net/sock.h>
> @@ -241,6 +242,14 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
>  static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
>  					struct virtio_vsock_pkt *pkt)
>  {
> +	if (vvs->usr_poll_page) {
> +		struct virtio_vsock_usr_hdr_pref *hdr;
> +
> +		hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> +		hdr->poll_value = VIRTIO_VSOCK_USR_POLL_HAS_DATA;
> +	}
> +
>  	if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
>  		return false;
>
> @@ -253,6 +262,14 @@ static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
>  {
>  	vvs->rx_bytes -= pkt->len;
>  	vvs->fwd_cnt += pkt->len;
> +
> +	if (!vvs->rx_bytes && vvs->usr_poll_page) {
> +		struct virtio_vsock_usr_hdr_pref *hdr;
> +
> +		hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> +		hdr->poll_value = VIRTIO_VSOCK_USR_POLL_NO_DATA;
> +	}
>  }
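
A note for readers following the shared-page protocol above: the first page of
the user's mapping starts with a struct virtio_vsock_usr_hdr_pref, and the
kernel flips poll_value between the three VIRTIO_VSOCK_USR_POLL_* states from
virtio_transport_inc_rx_pkt()/virtio_transport_dec_rx_pkt(). A rough sketch of
how a userspace consumer might watch that page; the mmap() plumbing and offset
layout come from elsewhere in this series, so vsock_fd, rx_region_size and the
offset-0 placement here are assumptions, not part of this patch:

	#include <stdint.h>
	#include <stddef.h>
	#include <sys/mman.h>

	struct virtio_vsock_usr_hdr_pref {
		uint32_t poll_value;	/* VIRTIO_VSOCK_USR_POLL_* */
		uint32_t hdr_num;	/* valid usr_hdr records that follow */
	} __attribute__((packed));

	static void wait_for_rx(int vsock_fd, size_t rx_region_size)
	{
		volatile struct virtio_vsock_usr_hdr_pref *pref;

		/* Assumed: the poll page is mapped at offset 0. */
		pref = mmap(NULL, rx_region_size, PROT_READ, MAP_SHARED,
			    vsock_fd, 0);
		if (pref == MAP_FAILED)
			return;

		while (pref->poll_value == 0)	/* ..._NO_DATA */
			;	/* real code would block in poll() instead */

		if (pref->poll_value == ~0u)	/* ..._SHUTDOWN */
			;	/* peer shutdown: drain what is left, close */
	}
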
>
>  void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
> @@ -347,6 +364,191 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk,
>  	return err;
>  }
>
> +int virtio_transport_zerocopy_init(struct vsock_sock *vsk,
> +				   struct vm_area_struct *vma)
> +{
> +	struct virtio_vsock_sock *vvs;
> +	int err = 0;
> +
> +	if (vma->vm_end - vma->vm_start < 2 * PAGE_SIZE)
> +		return -EINVAL;
> +
> +	vvs = vsk->trans;
> +
> +	spin_lock_bh(&vvs->rx_lock);
> +
> +	if (!vvs->usr_poll_page) {
> +		/* GFP_ATOMIC because of spinlock. */
> +		vvs->usr_poll_page = alloc_page(GFP_KERNEL | GFP_ATOMIC);

^^^ oops, only GFP_ATOMIC is needed

> +
> +		if (!vvs->usr_poll_page) {
> +			err = -ENOMEM;
> +		} else {
> +			struct virtio_vsock_usr_hdr_pref *usr_hdr_pref;
> +			unsigned long one_page = 1;
> +
> +			usr_hdr_pref = page_to_virt(vvs->usr_poll_page);
> +
> +			if (vsk->peer_shutdown & SHUTDOWN_MASK) {
> +				usr_hdr_pref->poll_value = VIRTIO_VSOCK_USR_POLL_SHUTDOWN;
> +			} else {
> +				usr_hdr_pref->poll_value = vvs->rx_bytes ?
> +						VIRTIO_VSOCK_USR_POLL_HAS_DATA :
> +						VIRTIO_VSOCK_USR_POLL_NO_DATA;
> +			}
> +
> +			usr_hdr_pref->hdr_num = 0;
> +
> +			err = vm_insert_pages(vma, vma->vm_start,
> +					      &vvs->usr_poll_page,
> +					      &one_page);
> +
> +			if (one_page)
> +				err = -EINVAL;
> +		}
> +	} else {
> +		err = -EINVAL;
> +	}
> +
> +	spin_unlock_bh(&vvs->rx_lock);
> +
> +	return err;
> +}
> +EXPORT_SYMBOL_GPL(virtio_transport_zerocopy_init);
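
On the allocation flagged above: this path already holds vvs->rx_lock, a BH
spinlock, so the "oops" note is right. GFP_KERNEL allows direct reclaim and
sleeping, so it must not be or'ed into an allocation made under a spinlock;
GFP_ATOMIC alone is the correct mask. The fixed line would simply be:

	/* GFP_ATOMIC because of spinlock. */
	vvs->usr_poll_page = alloc_page(GFP_ATOMIC);
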
> +
> +int virtio_transport_zerocopy_dequeue(struct vsock_sock *vsk,
> +				      struct page **pages,
> +				      unsigned long *pages_num)
> +{
> +	struct virtio_vsock_usr_hdr_pref *usr_hdr_pref;
> +	struct virtio_vsock_usr_hdr *usr_hdr_buffer;
> +	struct virtio_vsock_sock *vvs;
> +	unsigned long max_usr_hdrs;
> +	struct page *usr_hdr_page;
> +	int pages_cnt;
> +
> +	if (*pages_num < 2)
> +		return -EINVAL;
> +
> +	vvs = vsk->trans;
> +
> +	max_usr_hdrs = (PAGE_SIZE - sizeof(*usr_hdr_pref)) / sizeof(*usr_hdr_buffer);
> +	*pages_num = min(max_usr_hdrs, *pages_num);
> +	pages_cnt = 0;
> +
> +	spin_lock_bh(&vvs->rx_lock);
> +
> +	if (!vvs->usr_poll_page) {
> +		spin_unlock_bh(&vvs->rx_lock);
> +		return -EINVAL;
> +	}
> +
> +	usr_hdr_page = vvs->usr_poll_page;
> +	usr_hdr_pref = page_to_virt(usr_hdr_page);
> +	usr_hdr_buffer = (struct virtio_vsock_usr_hdr *)(usr_hdr_pref + 1);
> +	usr_hdr_pref->hdr_num = 0;
> +
> +	/* If ref counter is 1, then page owned during
> +	 * allocation and not mapped, so insert it to
> +	 * the output array. It will be mapped.
> +	 */
> +	if (page_ref_count(usr_hdr_page) == 1) {
> +		pages[pages_cnt++] = usr_hdr_page;
> +		/* Inc ref one more, as AF_VSOCK layer calls
> +		 * 'put_page()' for each returned page.
> +		 */
> +		get_page(usr_hdr_page);
> +	} else {
> +		pages[pages_cnt++] = NULL;
> +	}
> +
> +	/* Polling page is already mapped. */
> +	while (!list_empty(&vvs->rx_queue) &&
> +	       pages_cnt < *pages_num) {
> +		struct virtio_vsock_pkt *pkt;
> +		ssize_t rest_data_bytes;
> +		size_t moved_data_bytes;
> +		unsigned long pg_offs;
> +
> +		pkt = list_first_entry(&vvs->rx_queue,
> +				       struct virtio_vsock_pkt, list);
> +
> +		rest_data_bytes = le32_to_cpu(pkt->hdr.len) - pkt->off;
> +
> +		/* For packets, bigger than one page, split it's
> +		 * high order allocated buffer to 0 order pages.
> +		 * Otherwise 'vm_insert_pages()' will fail, for
> +		 * all pages except first.
> +		 */
> +		if (rest_data_bytes > PAGE_SIZE) {
> +			/* High order buffer not split yet. */
> +			if (!pkt->split) {
> +				split_page(virt_to_page(pkt->buf),
> +					   get_order(le32_to_cpu(pkt->hdr.len)));
> +				pkt->split = true;
> +			}
> +		}
> +
> +		pg_offs = pkt->off;
> +		moved_data_bytes = 0;
> +
> +		while (rest_data_bytes &&
> +		       pages_cnt < *pages_num) {
> +			struct page *buf_page;
> +
> +			buf_page = virt_to_page(pkt->buf + pg_offs);
> +
> +			pages[pages_cnt++] = buf_page;
> +			/* Get reference to prevent this page being
> +			 * returned to page allocator when packet will
> +			 * be freed. Ref count will be 2.
> +			 */
> +			get_page(buf_page);
> +			pg_offs += PAGE_SIZE;
> +
> +			if (rest_data_bytes >= PAGE_SIZE) {
> +				moved_data_bytes += PAGE_SIZE;
> +				rest_data_bytes -= PAGE_SIZE;
> +			} else {
> +				moved_data_bytes += rest_data_bytes;
> +				rest_data_bytes = 0;
> +			}
> +		}
> +
> +		if (!rest_data_bytes)
> +			usr_hdr_buffer->flags = le32_to_cpu(pkt->hdr.flags);
> +		else
> +			usr_hdr_buffer->flags = 0;
> +
> +		usr_hdr_buffer->len = moved_data_bytes;
> +
> +		usr_hdr_buffer++;
> +		usr_hdr_pref->hdr_num++;
> +
> +		pkt->off = pg_offs;
> +
> +		if (rest_data_bytes == 0) {
> +			list_del(&pkt->list);
> +			virtio_transport_dec_rx_pkt(vvs, pkt);
> +			virtio_transport_free_pkt(pkt);
> +		}
> +
> +		/* Now ref count for all pages of packet is 1. */
> +	}
> +
> +	if (*pages_num - 1 < max_usr_hdrs)
> +		usr_hdr_buffer->len = 0;
> +
> +	spin_unlock_bh(&vvs->rx_lock);
> +
> +	virtio_transport_send_credit_update(vsk);
> +
> +	*pages_num = pages_cnt;
> +
> +	return 0;
> +}
> +EXPORT_SYMBOL_GPL(virtio_transport_zerocopy_dequeue);
> +
>  static ssize_t
>  virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
>  				   struct msghdr *msg,
> @@ -969,11 +1171,21 @@ void virtio_transport_release(struct vsock_sock *vsk)
>  {
>  	struct sock *sk = &vsk->sk;
>  	bool remove_sock = true;
> +	struct virtio_vsock_sock *vvs = vsk->trans;
>
>  	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
>  		remove_sock = virtio_transport_close(vsk);
>
>  	if (remove_sock) {
> +		spin_lock_bh(&vvs->rx_lock);
> +
> +		if (vvs->usr_poll_page) {
> +			__free_page(vvs->usr_poll_page);
> +			vvs->usr_poll_page = NULL;
> +		}
> +
> +		spin_unlock_bh(&vvs->rx_lock);
> +
>  		sock_set_flag(sk, SOCK_DONE);
>  		virtio_transport_remove_sock(vsk);
>  	}
> @@ -1077,6 +1289,7 @@ virtio_transport_recv_connected(struct sock *sk,
>  				struct virtio_vsock_pkt *pkt)
>  {
>  	struct vsock_sock *vsk = vsock_sk(sk);
> +	struct virtio_vsock_sock *vvs = vsk->trans;
>  	int err = 0;
>
>  	switch (le16_to_cpu(pkt->hdr.op)) {
> @@ -1095,6 +1308,19 @@ virtio_transport_recv_connected(struct sock *sk,
>  			vsk->peer_shutdown |= RCV_SHUTDOWN;
>  		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
>  			vsk->peer_shutdown |= SEND_SHUTDOWN;
> +
> +		spin_lock_bh(&vvs->rx_lock);
> +
> +		if (vvs->usr_poll_page) {
> +			struct virtio_vsock_usr_hdr_pref *hdr;
> +
> +			hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> +			hdr->poll_value = 0xffffffff;
> +		}
> +
> +		spin_unlock_bh(&vvs->rx_lock);
> +
>  		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
>  		    vsock_stream_has_data(vsk) <= 0 &&
>  		    !sock_flag(sk, SOCK_DONE)) {
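
A small consistency nit on the shutdown hunk just above: the literal 0xffffffff
is the same value this patch names VIRTIO_VSOCK_USR_POLL_SHUTDOWN (~0) in the
uapi header, and virtio_transport_zerocopy_init() already writes it by name, so
both writers of poll_value would stay in sync (and greppable) as:

	hdr->poll_value = VIRTIO_VSOCK_USR_POLL_SHUTDOWN;
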
> @@ -1343,11 +1569,21 @@ EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
>  void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
>  {
>  	if (pkt->buf_len) {
> -		if (pkt->slab_buf)
> +		if (pkt->slab_buf) {
>  			kvfree(pkt->buf);
> -		else
> -			free_pages((unsigned long)pkt->buf,
> -				   get_order(pkt->buf_len));
> +		} else {
> +			unsigned int order = get_order(pkt->buf_len);
> +			unsigned long buf = (unsigned long)pkt->buf;
> +
> +			if (pkt->split) {
> +				int i;
> +
> +				for (i = 0; i < (1 << order); i++)
> +					free_page(buf + i * PAGE_SIZE);
> +			} else {
> +				free_pages(buf, order);
> +			}
> +		}
>  	}
>
>  	kfree(pkt);
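
For readers less familiar with the mm side of this free path: split_page()
turns one high-order allocation into 2^order independent order-0 pages, each
with its own reference count. That is what lets pages still pinned by the
user's mapping outlive the packet, while unpinned ones go straight back to the
allocator, and it is also why a split buffer must be freed page by page rather
than with one free_pages() call. A minimal illustration of that contract
(generic mm usage, not code from this patch):

	#include <linux/gfp.h>
	#include <linux/mm.h>

	static void split_page_demo(void)
	{
		/* One order-2 allocation: four physically contiguous pages. */
		struct page *page = alloc_pages(GFP_KERNEL, 2);
		int i;

		if (!page)
			return;

		split_page(page, 2);	/* now four independent order-0 pages */

		for (i = 0; i < 4; i++)
			__free_page(page + i);	/* each freed individually */
	}
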