Signed-off-by: Jason Wang <jasowang@xxxxxxxxxx>
---
 drivers/vhost/net.c   |  13 +-
 drivers/vhost/vhost.c | 593 ++++++++++++++++++++++++++++++++++++++++++++++----
 drivers/vhost/vhost.h |  13 +-
 3 files changed, 574 insertions(+), 45 deletions(-)

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 30273ad..4991aa4 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -71,7 +71,8 @@ enum {
 	VHOST_NET_FEATURES = VHOST_FEATURES |
 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
 			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
-			 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
+			 (1ULL << VIRTIO_F_IOMMU_PLATFORM) |
+			 (1ULL << VIRTIO_F_RING_PACKED)
 };
 
 enum {
@@ -576,7 +577,7 @@ static void handle_tx(struct vhost_net *net)
 				nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
 					% UIO_MAXIOV;
 			}
-			vhost_discard_vq_desc(vq, 1);
+			vhost_discard_vq_desc(vq, &used, 1);
 			vhost_net_enable_vq(net, vq);
 			break;
 		}
@@ -714,9 +715,11 @@ static void handle_rx(struct vhost_net *net)
 	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 
 	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
+		struct vhost_used_elem *used = vq->heads + nheads;
+
 		sock_len += sock_hlen;
 		vhost_len = sock_len + vhost_hlen;
-		err = vhost_get_bufs(vq, vq->heads + nheads, vhost_len,
+		err = vhost_get_bufs(vq, used, vhost_len,
 				     &in, vq_log, &log,
 				     likely(mergeable) ? UIO_MAXIOV : 1,
 				     &headcount);
@@ -762,7 +765,7 @@ static void handle_rx(struct vhost_net *net)
 		if (unlikely(err != sock_len)) {
 			pr_debug("Discarded rx packet: "
 				 " len %d, expected %zd\n", err, sock_len);
-			vhost_discard_vq_desc(vq, headcount);
+			vhost_discard_vq_desc(vq, used, headcount);
 			continue;
 		}
 		/* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
@@ -786,7 +789,7 @@ static void handle_rx(struct vhost_net *net)
 		    copy_to_iter(&num_buffers, sizeof num_buffers,
 				 &fixup) != sizeof num_buffers) {
 			vq_err(vq, "Failed num_buffers write");
-			vhost_discard_vq_desc(vq, headcount);
+			vhost_discard_vq_desc(vq, used, headcount);
 			goto out;
 		}
 		nheads += headcount;
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 4031a8f..a36e5ad2 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -323,6 +323,9 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vhost_reset_is_le(vq);
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
+	vq->used_wrap_counter = true;
+	vq->last_avail_wrap_counter = true;
+	vq->avail_wrap_counter = true;
 	vq->umem = NULL;
 	vq->iotlb = NULL;
 	__vhost_vq_meta_reset(vq);
@@ -1103,11 +1106,22 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
 	return 0;
 }
 
-static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
-			 struct vring_desc __user *desc,
-			 struct vring_avail __user *avail,
-			 struct vring_used __user *used)
+static bool vq_access_ok_packed(struct vhost_virtqueue *vq, unsigned int num,
+				struct vring_desc __user *desc,
+				struct vring_avail __user *avail,
+				struct vring_used __user *used)
+{
+	struct vring_desc_packed __user *packed = (void __user *)desc;
+
+	/* FIXME: check device area and driver area */
+	return access_ok(VERIFY_READ, packed, num * sizeof(*packed)) &&
+	       access_ok(VERIFY_WRITE, packed, num * sizeof(*packed));
+}
+static bool vq_access_ok_split(struct vhost_virtqueue *vq, unsigned int num,
+			       struct vring_desc __user *desc,
+			       struct vring_avail __user *avail,
+			       struct vring_used __user *used)
 
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
@@ -1118,6 +1132,17 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			sizeof *used + num * sizeof *used->ring + s);
 }
 
+static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
+			 struct vring_desc __user *desc,
+			 struct vring_avail __user *avail,
+			 struct vring_used __user *used)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return vq_access_ok_packed(vq, num, desc, avail, used);
+	else
+		return vq_access_ok_split(vq, num, desc, avail, used);
+}
+
 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
 				 const struct vhost_umem_node *node,
 				 int type)
@@ -1361,6 +1386,10 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 			break;
 		}
 		vq->last_avail_idx = s.num;
+		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
+			vq->last_avail_wrap_counter = s.num >> 31;
+			vq->avail_wrap_counter = vq->last_avail_wrap_counter;
+		}
 		/* Forget the cached index value. */
 		vq->avail_idx = vq->last_avail_idx;
 		break;
@@ -1369,6 +1398,8 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 		s.num = vq->last_avail_idx;
+		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+			s.num |= (u32)vq->last_avail_wrap_counter << 31;
 		if (copy_to_user(argp, &s, sizeof s))
 			r = -EFAULT;
 		break;
 	case VHOST_SET_VRING_ADDR:
 		if (copy_from_user(&a, argp, sizeof a)) {
@@ -1730,6 +1761,9 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
 
 	vhost_init_is_le(vq);
 
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return 0;
+
 	r = vhost_update_used_flags(vq);
 	if (r)
 		goto err;
@@ -1803,7 +1837,8 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
  * function returns the next descriptor in the chain,
  * or -1U if we're at the end. */
-static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
+static unsigned next_desc_split(struct vhost_virtqueue *vq,
+				struct vring_desc *desc)
 {
 	unsigned int next;
 
@@ -1816,11 +1851,17 @@ static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
 	return next;
 }
 
-static int get_indirect(struct vhost_virtqueue *vq,
-			struct iovec iov[], unsigned int iov_size,
-			unsigned int *out_num, unsigned int *in_num,
-			struct vhost_log *log, unsigned int *log_num,
-			struct vring_desc *indirect)
+static bool next_desc_packed(struct vhost_virtqueue *vq,
+			     struct vring_desc_packed *desc)
+{
+	return desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT);
+}
+
+static int get_indirect_split(struct vhost_virtqueue *vq,
+			      struct iovec iov[], unsigned int iov_size,
+			      unsigned int *out_num, unsigned int *in_num,
+			      struct vhost_log *log, unsigned int *log_num,
+			      struct vring_desc *indirect)
 {
 	struct vring_desc desc;
 	unsigned int i = 0, count, found = 0;
@@ -1910,23 +1951,309 @@ static int get_indirect(struct vhost_virtqueue *vq,
 			}
 			*out_num += ret;
 		}
-	} while ((i = next_desc(vq, &desc)) != -1);
+	} while ((i = next_desc_split(vq, &desc)) != -1);
 	return 0;
 }
 
-/* This looks in the virtqueue and for the first available buffer, and converts
- * it to an iovec for convenient access.  Since descriptors consist of some
- * number of output then some number of input descriptors, it's actually two
- * iovecs, but we pack them into one and note how many of each there were.
- *
- * This function returns the descriptor number found, or vq->num (which is
- * never a valid descriptor number) if none was found.  A negative code is
- * returned on error. */
-int vhost_get_vq_desc(struct vhost_virtqueue *vq,
-		      struct vhost_used_elem *used,
-		      struct iovec iov[], unsigned int iov_size,
-		      unsigned int *out_num, unsigned int *in_num,
-		      struct vhost_log *log, unsigned int *log_num)
+static int get_indirect_packed(struct vhost_virtqueue *vq,
+			       struct iovec iov[], unsigned int iov_size,
+			       unsigned int *out_num, unsigned int *in_num,
+			       struct vhost_log *log, unsigned int *log_num,
+			       struct vring_desc_packed *indirect)
+{
+	struct vring_desc_packed desc;
+	unsigned int i = 0, count, found = 0;
+	u32 len = vhost32_to_cpu(vq, indirect->len);
+	struct iov_iter from;
+	int ret, access;
+
+	/* Sanity check */
+	if (unlikely(len % sizeof(desc))) {
+		vq_err(vq, "Invalid length in indirect descriptor: "
+		       "len 0x%llx not multiple of 0x%zx\n",
+		       (unsigned long long)len,
+		       sizeof desc);
+		return -EINVAL;
+	}
+
+	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr),
+			     len, vq->indirect,
+			     UIO_MAXIOV, VHOST_ACCESS_RO);
+	if (unlikely(ret < 0)) {
+		if (ret != -EAGAIN)
+			vq_err(vq, "Translation failure %d in indirect.\n",
+			       ret);
+		return ret;
+	}
+	iov_iter_init(&from, READ, vq->indirect, ret, len);
+
+	/* We will use the result as an address to read from, so most
+	 * architectures only need a compiler barrier here. */
+	read_barrier_depends();
+
+	count = len / sizeof desc;
+	/* Limit the table to 2^16 entries, the same bound that the 16 bit
+	 * next field imposes on split rings. */
+	if (unlikely(count > USHRT_MAX + 1)) {
+		vq_err(vq, "Indirect buffer length too big: %d\n",
+		       len);
+		return -E2BIG;
+	}
+
+	do {
+		unsigned iov_count = *in_num + *out_num;
+		if (unlikely(++found > count)) {
+			vq_err(vq, "Loop detected: last one at %u "
+			       "indirect size %u\n",
+			       i, count);
+			return -EINVAL;
+		}
+		if (unlikely(!copy_from_iter_full(&desc, sizeof(desc),
+						  &from))) {
+			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
+			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) +
+			       i * sizeof desc);
+			return -EINVAL;
+		}
+		if (unlikely(desc.flags &
+			     cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
+			vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
+			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) +
+			       i * sizeof desc);
+			return -EINVAL;
+		}
+
+		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
+			access = VHOST_ACCESS_WO;
+		else
+			access = VHOST_ACCESS_RO;
+
+		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
+				     vhost32_to_cpu(vq, desc.len),
+				     iov + iov_count,
+				     iov_size - iov_count, access);
+		if (unlikely(ret < 0)) {
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d "
+				       "indirect idx %d\n",
+				       ret, i);
+			return ret;
+		}
+		/* If this is an input descriptor, increment that count. */
+		if (access == VHOST_ACCESS_WO) {
+			*in_num += ret;
+			if (unlikely(log)) {
+				log[*log_num].addr =
+					vhost64_to_cpu(vq, desc.addr);
+				log[*log_num].len =
+					vhost32_to_cpu(vq, desc.len);
+				++*log_num;
+			}
+		} else {
+			/* If it's an output descriptor, they're all supposed
+			 * to come before any input descriptors. */
+			if (unlikely(*in_num)) {
+				vq_err(vq, "Indirect descriptor "
+				       "has out after in: idx %d\n", i);
+				return -EINVAL;
+			}
+			*out_num += ret;
+		}
+		i++;
+	} while (i < count);
+	return 0;
+}
+
+#define DESC_AVAIL (1 << VRING_DESC_F_AVAIL)
+#define DESC_USED  (1 << VRING_DESC_F_USED)
+static bool desc_is_avail(struct vhost_virtqueue *vq, bool wrap_counter,
+			  __virtio16 flags)
+{
+	bool avail = flags & cpu_to_vhost16(vq, DESC_AVAIL);
+
+	return avail == wrap_counter;
+}
+
+static __virtio16 get_desc_flags(struct vhost_virtqueue *vq, bool write)
+{
+	__virtio16 flags = 0;
+
+	if (vq->used_wrap_counter) {
+		flags |= cpu_to_vhost16(vq, DESC_AVAIL);
+		flags |= cpu_to_vhost16(vq, DESC_USED);
+	} else {
+		flags &= ~cpu_to_vhost16(vq, DESC_AVAIL);
+		flags &= ~cpu_to_vhost16(vq, DESC_USED);
+	}
+
+	if (write)
+		flags |= cpu_to_vhost16(vq, VRING_DESC_F_WRITE);
+
+	return flags;
+}
+
+static bool vhost_vring_packed_need_event(struct vhost_virtqueue *vq,
+					  bool wrap, __u16 off_wrap, __u16 new,
+					  __u16 old)
+{
+	int off = off_wrap & ~(1 << 15);
+
+	if (new < old) {
+		new += vq->num;
+		wrap ^= 1;
+	}
+
+	if (wrap != off_wrap >> 15)
+		off += vq->num;
+
+	return vring_need_event(off, new, old);
+}
+
+static int vhost_get_vq_desc_packed(struct vhost_virtqueue *vq,
+				    struct vhost_used_elem *used,
+				    struct iovec iov[], unsigned int iov_size,
+				    unsigned int *out_num, unsigned int *in_num,
+				    struct vhost_log *log,
+				    unsigned int *log_num)
+{
+	struct vring_desc_packed desc;
+	int ret, access;
+	u16 i = vq->last_avail_idx;
+	bool wrap_counter = vq->last_avail_wrap_counter;
+	u16 last_avail_idx = vq->last_avail_idx;
+	u16 off_wrap = vq->avail_idx | (vq->avail_wrap_counter << 15);
+
+	/* When we start there are none of either input nor output. */
+	*out_num = *in_num = 0;
+	if (unlikely(log))
+		*log_num = 0;
+
+	used->count = 0;
+
+	do {
+		struct vring_desc_packed __user *d = vq->desc_packed + i;
+		unsigned int iov_count = *in_num + *out_num;
+
+		ret = vhost_get_user(vq, desc.flags, &d->flags,
+				     VHOST_ADDR_DESC);
+		if (unlikely(ret)) {
+			vq_err(vq, "Failed to get flags: idx %d addr %p\n",
+			       i, &d->flags);
+			return -EFAULT;
+		}
+
+		if (!desc_is_avail(vq, wrap_counter, desc.flags)) {
+			/* If there's nothing new since last we looked, return
+			 * invalid.
+			 */
+			if (!used->count)
+				return -ENOSPC;
+			vq_err(vq, "Unexpected unavail descriptor: idx %d\n",
+			       i);
+			return -EFAULT;
+		}
+
+		/* Read desc content after we're sure it was available. */
+		smp_rmb();
+
+		ret = vhost_copy_from_user(vq, &desc, d, sizeof(desc));
+		if (unlikely(ret)) {
+			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
+			       i, d);
+			return -EFAULT;
+		}
+
+		used->elem.id = desc.id;
+
+		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
+			ret = get_indirect_packed(vq, iov, iov_size,
+						  out_num, in_num, log,
+						  log_num, &desc);
+			if (unlikely(ret < 0)) {
+				if (ret != -EAGAIN)
+					vq_err(vq, "Failure detected "
+					       "in indirect descriptor "
+					       "at idx %d\n", i);
+				return ret;
+			}
+			goto next;
+		}
+
+		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
+			access = VHOST_ACCESS_WO;
+		else
+			access = VHOST_ACCESS_RO;
+		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
+				     vhost32_to_cpu(vq, desc.len),
+				     iov + iov_count, iov_size - iov_count,
+				     access);
+		if (unlikely(ret < 0)) {
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d idx %d\n",
+				       ret, i);
+			return ret;
+		}
+
+		if (access == VHOST_ACCESS_WO) {
+			/* If this is an input descriptor,
+			 * increment that count.
+			 */
+			*in_num += ret;
+			if (unlikely(log)) {
+				log[*log_num].addr =
+					vhost64_to_cpu(vq, desc.addr);
+				log[*log_num].len =
+					vhost32_to_cpu(vq, desc.len);
+				++*log_num;
+			}
+		} else {
+			/* If it's an output descriptor, they're all supposed
+			 * to come before any input descriptors.
+			 */
+			if (unlikely(*in_num)) {
+				vq_err(vq, "Desc out after in: idx %d\n",
+				       i);
+				return -EINVAL;
+			}
+			*out_num += ret;
+		}
+
+next:
+		if (unlikely(++used->count > vq->num)) {
+			vq_err(vq, "Loop detected: last one at %u "
+			       "vq size %u head %u\n",
+			       i, vq->num, used->elem.id);
+			return -EINVAL;
+		}
+		if (++i >= vq->num) {
+			i = 0;
+			wrap_counter ^= 1;
+		}
+		/* If this descriptor says it doesn't chain, we're done. */
+	} while (next_desc_packed(vq, &desc));
+
+	/* Commit the new position only once the whole chain has been
+	 * parsed, so that on error (e.g. -EAGAIN after an IOTLB miss)
+	 * the same chain is fetched again from its head.
+	 */
+	vq->last_avail_idx = i;
+	vq->last_avail_wrap_counter = wrap_counter;
+
+	if (vhost_vring_packed_need_event(vq, wrap_counter,
+					  off_wrap, i,
+					  last_avail_idx)) {
+		vq->avail_idx = i;
+		vq->avail_wrap_counter = wrap_counter;
+	}
+
+	return 0;
+}
+
+static int vhost_get_vq_desc_split(struct vhost_virtqueue *vq,
+				   struct vhost_used_elem *used,
+				   struct iovec iov[], unsigned int iov_size,
+				   unsigned int *out_num, unsigned int *in_num,
+				   struct vhost_log *log, unsigned int *log_num)
 {
 	struct vring_desc desc;
 	unsigned int i, head, found = 0;
@@ -2011,9 +2338,9 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 			return -EFAULT;
 		}
 		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
-			ret = get_indirect(vq, iov, iov_size,
-					   out_num, in_num,
-					   log, log_num, &desc);
+			ret = get_indirect_split(vq, iov, iov_size,
+						 out_num, in_num,
+						 log, log_num, &desc);
 			if (unlikely(ret < 0)) {
 				if (ret != -EAGAIN)
 					vq_err(vq, "Failure detected "
@@ -2055,7 +2382,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 			}
 			*out_num += ret;
 		}
-	} while ((i = next_desc(vq, &desc)) != -1);
+	} while ((i = next_desc_split(vq, &desc)) != -1);
 
 	/* On success, increment avail index. */
 	vq->last_avail_idx++;
@@ -2065,6 +2392,31 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
 	return 0;
 }
+
+/* This looks in the virtqueue for the first available buffer, and converts
+ * it to an iovec for convenient access.  Since descriptors consist of some
+ * number of output then some number of input descriptors, it's actually two
+ * iovecs, but we pack them into one and note how many of each there were.
+ *
+ * On success 0 is returned and @used describes the buffer that was found.
+ * A negative code is returned on error; with a packed ring, -ENOSPC means
+ * that no new buffer was available.
+ */
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+		      struct vhost_used_elem *used,
+		      struct iovec iov[], unsigned int iov_size,
+		      unsigned int *out_num, unsigned int *in_num,
+		      struct vhost_log *log, unsigned int *log_num)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return vhost_get_vq_desc_packed(vq, used, iov, iov_size,
+						out_num, in_num,
+						log, log_num);
+	else
+		return vhost_get_vq_desc_split(vq, used, iov, iov_size,
+					       out_num, in_num,
+					       log, log_num);
+}
 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
 
 void vhost_set_used_len(struct vhost_virtqueue *vq,
@@ -2151,15 +2503,30 @@ int vhost_get_bufs(struct vhost_virtqueue *vq,
 	*count = headcount;
 	return 0;
 err:
-	vhost_discard_vq_desc(vq, headcount);
+	vhost_discard_vq_desc(vq, heads, headcount);
 	return r;
 }
 EXPORT_SYMBOL_GPL(vhost_get_bufs);
 
 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq,
+			   struct vhost_used_elem *heads,
+			   int headcount)
 {
-	vq->last_avail_idx -= n;
+	int i;
+
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
+		for (i = 0; i < headcount; i++) {
+			vq->last_avail_idx -= heads[i].count;
+			if (vq->last_avail_idx >= vq->num) {
+				vq->last_avail_wrap_counter ^= 1;
+				vq->last_avail_idx += vq->num;
+			}
+		}
+	} else {
+		vq->last_avail_idx -= headcount;
+	}
+
 }
 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
 
@@ -2215,10 +2582,69 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
 	return 0;
 }
 
+static int vhost_add_used_n_packed(struct vhost_virtqueue *vq,
+				   struct vhost_used_elem *heads,
+				   unsigned int count)
+{
+	struct vring_desc_packed __user *desc;
+	int i, ret;
+
+	for (i = 0; i < count; i++) {
+		desc = vq->desc_packed + vq->last_used_idx;
+
+		ret = vhost_put_user(vq, heads[i].elem.id, &desc->id,
+				     VHOST_ADDR_DESC);
+		if (unlikely(ret)) {
+			vq_err(vq, "Failed to update id: idx %d addr %p\n",
+			       vq->last_used_idx, desc);
+			return -EFAULT;
+		}
+		ret = vhost_put_user(vq, heads[i].elem.len, &desc->len,
+				     VHOST_ADDR_DESC);
+		if (unlikely(ret)) {
+			vq_err(vq, "Failed to update len: idx %d addr %p\n",
+			       vq->last_used_idx, desc);
+			return -EFAULT;
+		}
+
+		/* Update flags after descriptor id and len are written,
+		 * TODO: Update head flags at last for saving barriers */
+		smp_wmb();
+
+		ret = vhost_put_user(vq, get_desc_flags(vq, heads[i].elem.len),
+				     &desc->flags, VHOST_ADDR_DESC);
+		if (unlikely(ret)) {
+			vq_err(vq, "Failed to update flags: idx %d addr %p\n",
+			       vq->last_used_idx, desc);
+			return -EFAULT;
+		}
+
+		if (unlikely(vq->log_used)) {
+			/* Make sure desc is written before update log. */
+			smp_wmb();
+			log_write(vq->log_base, vq->log_addr +
+				  vq->last_used_idx * sizeof(*desc),
+				  sizeof(*desc));
+			if (vq->log_ctx)
+				eventfd_signal(vq->log_ctx, 1);
+		}
+
+		vq->last_used_idx += heads[i].count;
+		if (vq->last_used_idx >= vq->num) {
+			vq->used_wrap_counter ^= 1;
+			vq->last_used_idx -= vq->num;
+		}
+	}
+
+	return 0;
+}
+
 /* After we've used one of their buffers, we tell them about it.  We'll then
  * want to notify the guest, using eventfd. */
-int vhost_add_used_n(struct vhost_virtqueue *vq, struct vhost_used_elem *heads,
-		     unsigned count)
+static int vhost_add_used_n_split(struct vhost_virtqueue *vq,
+				  struct vhost_used_elem *heads,
+				  unsigned count)
+
 {
 	int start, n, r;
 
@@ -2250,6 +2676,19 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vhost_used_elem *heads,
 	}
 	return r;
 }
+
+/* After we've used one of their buffers, we tell them about it.  We'll then
+ * want to notify the guest, using eventfd.
+ */
+int vhost_add_used_n(struct vhost_virtqueue *vq,
+		     struct vhost_used_elem *heads,
+		     unsigned int count)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return vhost_add_used_n_packed(vq, heads, count);
+	else
+		return vhost_add_used_n_split(vq, heads, count);
+}
 EXPORT_SYMBOL_GPL(vhost_add_used_n);
 
 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
@@ -2257,6 +2696,11 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 	__u16 old, new;
 	__virtio16 event;
 	bool v;
+
+	/* FIXME: check driver area */
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return true;
+
 	/* Flush out used index updates. This is paired
 	 * with the barrier that the Guest executes when enabling
 	 * interrupts. */
@@ -2319,7 +2763,8 @@ void vhost_add_used_and_signal_n(struct vhost_dev *dev,
 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
 
 /* return true if we're sure that avaiable ring is empty */
-bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+static bool vhost_vq_avail_empty_split(struct vhost_dev *dev,
+				       struct vhost_virtqueue *vq)
 {
 	__virtio16 avail_idx;
 	int r;
@@ -2334,10 +2779,58 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 
 	return vq->avail_idx == vq->last_avail_idx;
 }
+
+static bool vhost_vq_avail_empty_packed(struct vhost_dev *dev,
+					struct vhost_virtqueue *vq)
+{
+	struct vring_desc_packed __user *d = vq->desc_packed + vq->avail_idx;
+	__virtio16 flags;
+	int ret;
+
+	ret = vhost_get_user(vq, flags, &d->flags, VHOST_ADDR_DESC);
+	if (unlikely(ret)) {
+		vq_err(vq, "Failed to get flags: idx %d addr %p\n",
+		       vq->avail_idx, d);
+		return false;
+	}
+
+	return !desc_is_avail(vq, vq->avail_wrap_counter, flags);
+}
+
+bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return vhost_vq_avail_empty_packed(dev, vq);
+	else
+		return vhost_vq_avail_empty_split(dev, vq);
+}
 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
 
-/* OK, now we need to know about added descriptors. */
-bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+static bool vhost_enable_notify_packed(struct vhost_dev *dev,
+				       struct vhost_virtqueue *vq)
+{
+	struct vring_desc_packed __user *d = vq->desc_packed + vq->avail_idx;
+	__virtio16 flags;
+	int ret;
+
+	/* FIXME: enable notification through device area */
+
+	/* They could have slipped one in as we were doing that: make
+	 * sure it's written, then check again. */
+	smp_mb();
+
+	ret = vhost_get_user(vq, flags, &d->flags, VHOST_ADDR_DESC);
+	if (unlikely(ret)) {
+		vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
+		       vq->avail_idx, &d->flags);
+		return false;
+	}
+
+	return desc_is_avail(vq, vq->avail_wrap_counter, flags);
+}
+
+static bool vhost_enable_notify_split(struct vhost_dev *dev,
+				      struct vhost_virtqueue *vq)
 {
 	__virtio16 avail_idx;
 	int r;
@@ -2372,10 +2865,25 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 
 	return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
 }
+
+/* OK, now we need to know about added descriptors. */
+bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		return vhost_enable_notify_packed(dev, vq);
+	else
+		return vhost_enable_notify_split(dev, vq);
+}
 EXPORT_SYMBOL_GPL(vhost_enable_notify);
 
-/* We don't need to be notified again. */
-void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+static void vhost_disable_notify_packed(struct vhost_dev *dev,
+					struct vhost_virtqueue *vq)
+{
+	/* FIXME: disable notification through device area */
+}
+
+static void vhost_disable_notify_split(struct vhost_dev *dev,
+				       struct vhost_virtqueue *vq)
 {
 	int r;
 
@@ -2389,6 +2897,15 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 			       &vq->used->flags, r);
 	}
 }
+
+/* We don't need to be notified again. */
+void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
+{
+	if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+		vhost_disable_notify_packed(dev, vq);
+	else
+		vhost_disable_notify_split(dev, vq);
+}
 EXPORT_SYMBOL_GPL(vhost_disable_notify);
 
 /* Create a new message. */
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 604821b..7543a46 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -36,6 +36,7 @@ struct vhost_poll {
 
 struct vhost_used_elem {
 	struct vring_used_elem elem;
+	int count;
 };
 
 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn);
@@ -91,7 +92,10 @@ struct vhost_virtqueue {
 	/* The actual ring of buffers. */
 	struct mutex mutex;
 	unsigned int num;
-	struct vring_desc __user *desc;
+	union {
+		struct vring_desc __user *desc;
+		struct vring_desc_packed __user *desc_packed;
+	};
 	struct vring_avail __user *avail;
 	struct vring_used __user *used;
 	const struct vhost_umem_node *meta_iotlb[VHOST_NUM_ADDRS];
@@ -148,6 +152,9 @@ struct vhost_virtqueue {
 	bool user_be;
 #endif
 	u32 busyloop_timeout;
+	bool used_wrap_counter;
+	bool avail_wrap_counter;
+	bool last_avail_wrap_counter;
 };
 
 struct vhost_msg_node {
@@ -203,7 +210,9 @@ void vhost_set_used_len(struct vhost_virtqueue *vq,
 			int len);
 int vhost_get_used_len(struct vhost_virtqueue *vq,
 		       struct vhost_used_elem *used);
-void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
+void vhost_discard_vq_desc(struct vhost_virtqueue *,
+			   struct vhost_used_elem *,
+			   int n);
 int vhost_vq_init_access(struct vhost_virtqueue *);
 
 int vhost_add_used(struct vhost_virtqueue *vq,
-- 
2.7.4