On 2024-05-10 16:21, Mina Almasry wrote: > +/* On error, returns the -errno. On success, returns number of bytes sent to the > + * user. May not consume all of @remaining_len. > + */ > +static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb, > + unsigned int offset, struct msghdr *msg, > + int remaining_len) > +{ > + struct dmabuf_cmsg dmabuf_cmsg = { 0 }; > + struct tcp_xa_pool tcp_xa_pool; > + unsigned int start; > + int i, copy, n; > + int sent = 0; > + int err = 0; > + > + tcp_xa_pool.max = 0; > + tcp_xa_pool.idx = 0; > + do { > + start = skb_headlen(skb); > + > + if (skb_frags_readable(skb)) { > + err = -ENODEV; > + goto out; > + } > + > + /* Copy header. */ > + copy = start - offset; > + if (copy > 0) { > + copy = min(copy, remaining_len); > + > + n = copy_to_iter(skb->data + offset, copy, > + &msg->msg_iter); > + if (n != copy) { > + err = -EFAULT; > + goto out; > + } > + > + offset += copy; > + remaining_len -= copy; > + > + /* First a dmabuf_cmsg for # bytes copied to user > + * buffer. > + */ > + memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg)); > + dmabuf_cmsg.frag_size = copy; > + err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR, > + sizeof(dmabuf_cmsg), &dmabuf_cmsg); > + if (err || msg->msg_flags & MSG_CTRUNC) { > + msg->msg_flags &= ~MSG_CTRUNC; > + if (!err) > + err = -ETOOSMALL; > + goto out; > + } > + > + sent += copy; > + > + if (remaining_len == 0) > + goto out; > + } > + > + /* after that, send information of dmabuf pages through a > + * sequence of cmsg > + */ > + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { > + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; > + struct net_iov *niov; > + u64 frag_offset; > + int end; > + > + /* !skb_frags_readable() should indicate that ALL the > + * frags in this skb are dmabuf net_iovs. We're checking > + * for that flag above, but also check individual frags > + * here. If the tcp stack is not setting > + * skb_frags_readable() correctly, we still don't want > + * to crash here. 
> + */ > +		if (!skb_frag_net_iov(frag)) { > +			net_err_ratelimited("Found non-dmabuf skb with net_iov"); > +			err = -ENODEV; > +			goto out; > +		} > + > +		niov = skb_frag_net_iov(frag); Sorry if we've already discussed this. We have this additional hunk: + if (niov->pp->mp_ops != &dmabuf_devmem_ops) { + err = -ENODEV; + goto out; + } In case one of our skbs ends up here, both skb_frag_is_net_iov() and !skb_frags_readable() would be true, so the checks above would not catch it. Does this even matter? And if so, is there a better way to distinguish between our two types of net_iovs?