Jakub Kicinski wrote:
> On Fri, 1 Nov 2019 10:22:38 -0700, Jakub Kicinski wrote:
> > > > +		msg->sg.copybreak = 0;
> > > > +	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >
> > > > +		   sk_msg_iter_dist(msg->sg.end, msg->sg.curr)) {
> > > > +		sk_msg_iter_var_prev(i);
> > >
> > > I suspect with small update to dist logic the special case could also
> > > be dropped here. But I have a preference for my example above at the
> > > moment. Just getting coffee now so will think on it though.
> >
> > Okay, I like the dist thing, I thought that's where you were going in
> > your first email :)
> >
> > I need to do some more admin, and then I'll probably write a unit test
> > for this code (user space version). So we can test either patch with it.
>
> Attaching my "unit test", you should be able to just replace
> sk_msg_trim() with yours and re-run. That said my understanding of the
> expected geometry of the buffer may not be correct :)
>
> The patch I posted yesterday, with the small adjustment to set curr to
> start on empty message passes that test, here it is again:
>
> ----->8-----
>
> From 953df5bc0992e31a2c7863ea8b8e490ba7a07356 Mon Sep 17 00:00:00 2001
> From: Jakub Kicinski <jakub.kicinski@xxxxxxxxxxxxx>
> Date: Tue, 29 Oct 2019 20:20:49 -0700
> Subject: [PATCH net] net/tls: fix sk_msg trim on fallback to copy mode
>
> sk_msg_trim() tries to only update curr pointer if it falls into
> the trimmed region. The logic, however, does not take into the
> account pointer wrapping that sk_msg_iter_var_prev() does nor
> (as John points out) the fact that msg->sg is a ring buffer.
>
> This means that when the message was trimmed completely, the new
> curr pointer would have the value of MAX_MSG_FRAGS - 1, which is
> neither smaller than any other value, nor would it actually be
> correct.
>
> Special case the trimming to 0 length a little bit and rework
> the comparison between curr and end to take into account wrapping.
>
> This bug caused the TLS code to not copy all of the message, if
> zero copy filled in fewer sg entries than memcopy would need.
>
> Big thanks to Alexander Potapenko for the non-KMSAN reproducer.
>
> v2:
>  - take into account that msg->sg is a ring buffer (John).
>
> Fixes: d829e9c4112b ("tls: convert to generic sk_msg interface")
> Suggested-by: John Fastabend <john.fastabend@xxxxxxxxx>
> Reported-by: syzbot+f8495bff23a879a6d0bd@xxxxxxxxxxxxxxxxxxxxxxxxx
> Reported-by: syzbot+6f50c99e8f6194bf363f@xxxxxxxxxxxxxxxxxxxxxxxxx
> Signed-off-by: Jakub Kicinski <jakub.kicinski@xxxxxxxxxxxxx>
> ---
>  include/linux/skmsg.h |  9 ++++++---
>  net/core/skmsg.c      | 20 +++++++++++++++-----
>  2 files changed, 21 insertions(+), 8 deletions(-)
>
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index e4b3fb4bb77c..ce7055259877 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h
> @@ -139,6 +139,11 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
>  	}
>  }
>
> +static inline u32 sk_msg_iter_dist(u32 start, u32 end)
> +{
> +	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
> +}
> +
>  #define sk_msg_iter_var_prev(var)			\
>  	do {						\
>  		if (var == 0)				\
> @@ -198,9 +203,7 @@ static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
>  	if (sk_msg_full(msg))
>  		return MAX_MSG_FRAGS;
>
> -	return msg->sg.end >= msg->sg.start ?
> -		msg->sg.end - msg->sg.start :
> -		msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
> +	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
>  }

I think it's nice to pull this into a helper, so I'm OK with also using
the dist below, except for one comment below.
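
For anyone who wants to play with the numbers outside the kernel, here is
a minimal user-space sketch of the helper from the hunk above. The value
of MAX_MSG_FRAGS (17) is an assumption picked only for illustration; the
real value comes from the kernel headers. uint32_t stands in for u32.

```c
#include <stdint.h>

/* Assumption: illustrative ring size only; the kernel defines the real one. */
#define MAX_MSG_FRAGS 17u

/* Forward distance from 'start' to 'end' on a ring of MAX_MSG_FRAGS slots,
 * mirroring sk_msg_iter_dist() from the quoted patch.
 */
static inline uint32_t sk_msg_iter_dist(uint32_t start, uint32_t end)
{
	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
}
```

Note the helper is directional: it counts slots walking forward from the
first argument to the second, so swapping the arguments gives the
complementary distance around the ring rather than the same value.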
>
>  static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
> diff --git a/net/core/skmsg.c b/net/core/skmsg.c
> index cf390e0aa73d..f87fde3a846c 100644
> --- a/net/core/skmsg.c
> +++ b/net/core/skmsg.c
> @@ -270,18 +270,28 @@ void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
>
>  	msg->sg.data[i].length -= trim;
>  	sk_mem_uncharge(sk, trim);
> +	/* Adjust copybreak if it falls into the trimmed part of last buf */
> +	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
> +		msg->sg.copybreak = msg->sg.data[i].length;
>  out:
> -	/* If we trim data before curr pointer update copybreak and current
> -	 * so that any future copy operations start at new copy location.
> +	sk_msg_iter_var_next(i);
> +	msg->sg.end = i;
> +
> +	/* If we trim data a full sg elem before curr pointer update
> +	 * copybreak and current so that any future copy operations
> +	 * start at new copy location.
>  	 * However trimed data that has not yet been used in a copy op
>  	 * does not require an update.
>  	 */
> -	if (msg->sg.curr >= i) {
> +	if (!msg->sg.size) {
> +		msg->sg.curr = msg->sg.start;
> +		msg->sg.copybreak = 0;
> +	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >
> +		   sk_msg_iter_dist(msg->sg.end, msg->sg.curr)) {

I'm not seeing how this can work. Take the simple case with start < end,
so normal geometry without wrapping. Let,

  start = 1
  curr  = 3
  end   = 4

We could trim an index to get,

  start = 1
  curr  = 3
  i     = 3
  end   = 4

Then after the out: label this would push end up one,

  start = 1
  curr  = 3
  i     = 3
  end   = 4

But dist(start, curr) = 2 and dist(end, curr) = 1, and we would set curr
to '3' but clear the copybreak?

I think a better comparison would be,

  if (sk_msg_iter_dist(msg->sg.start, i) <
      sk_msg_iter_dist(msg->sg.start, msg->sg.curr))

to check if 'i' walked past curr so we can reset curr/copybreak?
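
To make the suggested check concrete, here is a rough user-space sketch,
not a tested kernel change. trim_passed_curr() is a hypothetical name
that exists only for this illustration, and MAX_MSG_FRAGS = 17 is again
an assumed value. Both positions are measured forward from start, so the
comparison should hold up when the ring wraps.

```c
#include <stdbool.h>
#include <stdint.h>

/* Assumption: illustrative ring size only; the kernel defines the real one. */
#define MAX_MSG_FRAGS 17u

/* Forward distance from 'start' to 'end' on the ring, as in the patch. */
static inline uint32_t sk_msg_iter_dist(uint32_t start, uint32_t end)
{
	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
}

/* Hypothetical helper, not part of the patch: true when the trim index 'i'
 * lies strictly before 'curr', both measured forward from 'start'.
 */
static inline bool trim_passed_curr(uint32_t start, uint32_t i, uint32_t curr)
{
	return sk_msg_iter_dist(start, i) < sk_msg_iter_dist(start, curr);
}
```

With start = 1, curr = 3, i = 3 from the example the check is false, so
curr and copybreak would be left alone. With i = 2 it is true. In a
wrapped layout such as start = 15, curr = 1, i = 16 it is also true.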
> +		sk_msg_iter_var_prev(i);
>  		msg->sg.curr = i;
>  		msg->sg.copybreak = msg->sg.data[i].length;
>  	}
> -	sk_msg_iter_var_next(i);
> -	msg->sg.end = i;
>  }
>  EXPORT_SYMBOL_GPL(sk_msg_trim);
>
> --
> 2.23.0
>