From: Andrea Parri (Microsoft) <parri.andrea@xxxxxxxxx> Sent: Monday, March 28, 2022 7:43 AM > > The function can be used to send a VMbus packet and retrieve the > corresponding transaction ID. It will be used by hv_pci. > > No functional change. > > Suggested-by: Michael Kelley <mikelley@xxxxxxxxxxxxx> > Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@xxxxxxxxx> > --- > drivers/hv/channel.c | 38 ++++++++++++++++++++++++++++++++------ > drivers/hv/hyperv_vmbus.h | 2 +- > drivers/hv/ring_buffer.c | 4 +++- > include/linux/hyperv.h | 7 +++++++ > 4 files changed, 43 insertions(+), 8 deletions(-) > > diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c > index a253eee3aeb1a..3eaa41c7ce15f 100644 > --- a/drivers/hv/channel.c > +++ b/drivers/hv/channel.c > @@ -1022,11 +1022,13 @@ void vmbus_close(struct vmbus_channel *channel) > EXPORT_SYMBOL_GPL(vmbus_close); > > /** > - * vmbus_sendpacket() - Send the specified buffer on the given channel > + * vmbus_sendpacket_getid() - Send the specified buffer on the given channel > * @channel: Pointer to vmbus_channel structure > * @buffer: Pointer to the buffer you want to send the data from. > * @bufferlen: Maximum size of what the buffer holds. > * @requestid: Identifier of the request > + * @trans_id: Identifier of the transaction associated to this request, if > + * the send is successful; undefined, otherwise. > * @type: Type of packet that is being sent e.g. negotiate, time > * packet etc. > * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED > @@ -1036,8 +1038,8 @@ EXPORT_SYMBOL_GPL(vmbus_close); > * > * Mainly used by Hyper-V drivers. 
> */ > -int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer, > - u32 bufferlen, u64 requestid, > +int vmbus_sendpacket_getid(struct vmbus_channel *channel, void *buffer, > + u32 bufferlen, u64 requestid, u64 *trans_id, > enum vmbus_packet_type type, u32 flags) > { > struct vmpacket_descriptor desc; > @@ -1063,7 +1065,31 @@ int vmbus_sendpacket(struct vmbus_channel *channel, > void *buffer, > bufferlist[2].iov_base = &aligned_data; > bufferlist[2].iov_len = (packetlen_aligned - packetlen); > > - return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid); > + return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid, trans_id); > +} > +EXPORT_SYMBOL(vmbus_sendpacket_getid); > + > +/** > + * vmbus_sendpacket() - Send the specified buffer on the given channel > + * @channel: Pointer to vmbus_channel structure > + * @buffer: Pointer to the buffer you want to send the data from. > + * @bufferlen: Maximum size of what the buffer holds. > + * @requestid: Identifier of the request > + * @type: Type of packet that is being sent e.g. negotiate, time > + * packet etc. > + * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED > + * > + * Sends data in @buffer directly to Hyper-V via the vmbus. > + * This will send the data unparsed to Hyper-V. > + * > + * Mainly used by Hyper-V drivers. 
> + */ > +int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer, > + u32 bufferlen, u64 requestid, > + enum vmbus_packet_type type, u32 flags) > +{ > + return vmbus_sendpacket_getid(channel, buffer, bufferlen, > + requestid, NULL, type, flags); > } > EXPORT_SYMBOL(vmbus_sendpacket); > > @@ -1122,7 +1148,7 @@ int vmbus_sendpacket_pagebuffer(struct vmbus_channel > *channel, > bufferlist[2].iov_base = &aligned_data; > bufferlist[2].iov_len = (packetlen_aligned - packetlen); > > - return hv_ringbuffer_write(channel, bufferlist, 3, requestid); > + return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL); > } > EXPORT_SYMBOL_GPL(vmbus_sendpacket_pagebuffer); > > @@ -1160,7 +1186,7 @@ int vmbus_sendpacket_mpb_desc(struct vmbus_channel > *channel, > bufferlist[2].iov_base = &aligned_data; > bufferlist[2].iov_len = (packetlen_aligned - packetlen); > > - return hv_ringbuffer_write(channel, bufferlist, 3, requestid); > + return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL); > } > EXPORT_SYMBOL_GPL(vmbus_sendpacket_mpb_desc); > > diff --git a/drivers/hv/hyperv_vmbus.h b/drivers/hv/hyperv_vmbus.h > index 3a1f007b678a0..64c0b9cbe183b 100644 > --- a/drivers/hv/hyperv_vmbus.h > +++ b/drivers/hv/hyperv_vmbus.h > @@ -181,7 +181,7 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info > *ring_info); > > int hv_ringbuffer_write(struct vmbus_channel *channel, > const struct kvec *kv_list, u32 kv_count, > - u64 requestid); > + u64 requestid, u64 *trans_id); > > int hv_ringbuffer_read(struct vmbus_channel *channel, > void *buffer, u32 buflen, u32 *buffer_actual_len, > diff --git a/drivers/hv/ring_buffer.c b/drivers/hv/ring_buffer.c > index 71efacb909659..c8561c80c460c 100644 > --- a/drivers/hv/ring_buffer.c > +++ b/drivers/hv/ring_buffer.c > @@ -283,7 +283,7 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info > *ring_info) > /* Write to the ring buffer. 
*/ > int hv_ringbuffer_write(struct vmbus_channel *channel, > const struct kvec *kv_list, u32 kv_count, > - u64 requestid) > + u64 requestid, u64 *trans_id) > { > int i; > u32 bytes_avail_towrite; > @@ -354,6 +354,8 @@ int hv_ringbuffer_write(struct vmbus_channel *channel, > } > desc = hv_get_ring_buffer(outring_info) + old_write; > desc->trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id; > + if (trans_id) > + *trans_id = desc->trans_id; This line should *not* read the trans_id back out of the ring buffer, since that memory is shared with the Hyper-V host and can be maliciously changed by the host at any time. We need to set *trans_id only from local variables, and somehow ensure that the compiler doesn't generate code that re-reads the value from the ring buffer. Maybe mark the desc->trans_id field as volatile, or cast it as such? Or does using WRITE_ONCE() for the store to desc->trans_id work? Michael > > /* Set previous packet start */ > prev_indices = hv_get_ring_bufferindices(outring_info); > diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h > index fe2e0179ed51e..a7cb596d893b1 100644 > --- a/include/linux/hyperv.h > +++ b/include/linux/hyperv.h > @@ -1161,6 +1161,13 @@ extern int vmbus_open(struct vmbus_channel *channel, > > extern void vmbus_close(struct vmbus_channel *channel); > > +extern int vmbus_sendpacket_getid(struct vmbus_channel *channel, > + void *buffer, > + u32 bufferLen, > + u64 requestid, > + u64 *trans_id, > + enum vmbus_packet_type type, > + u32 flags); > extern int vmbus_sendpacket(struct vmbus_channel *channel, > void *buffer, > u32 bufferLen, > -- > 2.25.1