On Mon, May 17, 2021 at 05:08:36PM +0800, Xie Yongji wrote: > This adds validation for used length (might come > from an untrusted device) when it will be used by > virtio device driver. > > Signed-off-by: Xie Yongji <xieyongji@xxxxxxxxxxxxx> > --- > drivers/virtio/virtio_ring.c | 22 +++++++++++++++++++--- > 1 file changed, 19 insertions(+), 3 deletions(-) > > diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c > index d999a1d6d271..7d4845d06f21 100644 > --- a/drivers/virtio/virtio_ring.c > +++ b/drivers/virtio/virtio_ring.c > @@ -68,11 +68,13 @@ > struct vring_desc_state_split { > void *data; /* Data for callback. */ > struct vring_desc *indir_desc; /* Indirect descriptor, if any. */ > + u32 in_len; /* Total length of writable buffer */ > }; > > struct vring_desc_state_packed { > void *data; /* Data for callback. */ > struct vring_packed_desc *indir_desc; /* Indirect descriptor, if any. */ > + u32 in_len; /* Total length of writable buffer */ > u16 num; /* Descriptor list length. */ > u16 last; /* The last desc state in a list. */ > }; Hmm, for packed the struct is already padded to a 64-bit boundary, so we are not making it any worse. But for split, this pushes the struct size up by half (from 16 to 24 bytes on 64-bit), increasing cache pressure. > @@ -486,7 +488,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > struct vring_virtqueue *vq = to_vvq(_vq); > struct scatterlist *sg; > struct vring_desc *desc; > - unsigned int i, n, avail, descs_used, prev, err_idx; > + unsigned int i, n, avail, descs_used, prev, err_idx, in_len = 0; > int head; > bool indirect; > > @@ -570,6 +572,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > VRING_DESC_F_NEXT | > VRING_DESC_F_WRITE, > indirect); > + in_len += sg->length; > } > } > /* Last one doesn't continue. */ > @@ -604,6 +607,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > > /* Store token and indirect buffer state.
*/ > vq->split.desc_state[head].data = data; > + vq->split.desc_state[head].in_len = in_len; > if (indirect) > vq->split.desc_state[head].indir_desc = desc; > else > @@ -784,6 +788,10 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq, > BAD_RING(vq, "id %u is not a head!\n", i); > return NULL; > } > + if (unlikely(len && vq->split.desc_state[i].in_len < *len)) { > + BAD_RING(vq, "id %u has invalid length: %u!\n", i, *len); > + return NULL; > + } > > /* detach_buf_split clears data, so grab it now. */ > ret = vq->split.desc_state[i].data; > @@ -1059,7 +1067,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > { > struct vring_packed_desc *desc; > struct scatterlist *sg; > - unsigned int i, n, err_idx; > + unsigned int i, n, err_idx, in_len = 0; > u16 head, id; > dma_addr_t addr; > > @@ -1084,6 +1092,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > if (vring_mapping_error(vq, addr)) > goto unmap_release; > > + in_len += (n < out_sgs) ? 0 : sg->length; > desc[i].flags = cpu_to_le16(n < out_sgs ? > 0 : VRING_DESC_F_WRITE); > desc[i].addr = cpu_to_le64(addr); > @@ -1141,6 +1150,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > vq->packed.desc_state[id].data = data; > vq->packed.desc_state[id].indir_desc = desc; > vq->packed.desc_state[id].last = id; > + vq->packed.desc_state[id].in_len = in_len; > > vq->num_added += 1; > > @@ -1173,7 +1183,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > struct vring_virtqueue *vq = to_vvq(_vq); > struct vring_packed_desc *desc; > struct scatterlist *sg; > - unsigned int i, n, c, descs_used, err_idx; > + unsigned int i, n, c, descs_used, err_idx, in_len = 0; > __le16 head_flags, flags; > u16 head, id, prev, curr, avail_used_flags; > > @@ -1223,6 +1233,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > if (vring_mapping_error(vq, addr)) > goto unmap_release; > > + in_len += (n < out_sgs) ? 
0 : sg->length; > flags = cpu_to_le16(vq->packed.avail_used_flags | > (++c == total_sg ? 0 : VRING_DESC_F_NEXT) | > (n < out_sgs ? 0 : VRING_DESC_F_WRITE)); > @@ -1268,6 +1279,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > vq->packed.desc_state[id].data = data; > vq->packed.desc_state[id].indir_desc = ctx; > vq->packed.desc_state[id].last = prev; > + vq->packed.desc_state[id].in_len = in_len; > > /* > * A driver MUST NOT make the first descriptor in the list > @@ -1456,6 +1468,10 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, > BAD_RING(vq, "id %u is not a head!\n", id); > return NULL; > } > + if (unlikely(len && vq->packed.desc_state[id].in_len < *len)) { > + BAD_RING(vq, "id %u has invalid length: %u!\n", id, *len); > + return NULL; > + } > > /* detach_buf_packed clears data, so grab it now. */ > ret = vq->packed.desc_state[id].data; > -- > 2.11.0 _______________________________________________ Virtualization mailing list Virtualization@xxxxxxxxxxxxxxxxxxxxxxxxxx https://lists.linuxfoundation.org/mailman/listinfo/virtualization