Because the dgram sendmsg() path for AF_VSOCK acquires the socket lock, it does not scale when many senders share a socket. Prior to this patch, the socket lock was used to protect both reads and writes to the local_addr, remote_addr, transport, and buffer size variables of a vsock socket. What follows are the new protection schemes for these fields that ensure a race-free and usually lock-free multi-sender sendmsg() path for vsock dgrams. - local_addr: local_addr changes as a result of binding a socket. The write paths for local_addr are bind() and the various vsock_auto_bind() call sites. After a socket has been bound via vsock_auto_bind() or bind(), subsequent calls to bind()/vsock_auto_bind() do not write to local_addr again: bind() rejects the request and vsock_auto_bind() exits early. Therefore, local_addr cannot change while a parallel thread is in sendmsg(), and lock-free reads of local_addr in sendmsg() are safe. Change: only acquire the socket lock in sendmsg() when auto-binding is needed. - buffer size variables: Not used by dgram, so they do not need protection. No change. - remote_addr and transport: A remote_addr update may result in a changed transport, yet we would like to read these two fields lock-free but coherently in the vsock send path. Therefore, this patch packages these two fields into a new struct vsock_remote_info that is referenced by an RCU-protected pointer. Writes are synchronized as usual by the socket lock. Reads take place either in RCU read-side critical sections or with the socket lock held. When remote_addr or transport is updated, a new remote info is allocated. Old readers still see the old coherent remote_addr/transport pair, and new readers will refer to the new coherent pair. The coherency between remote_addr and transport, previously provided by the socket lock alone, is now also preserved by RCU, but with a highly scalable, lock-free read side. Helpers are introduced for accessing and updating the new pointer. 
The new structure contains an rcu_head so that kfree_rcu() can be used. This removes the need for writers to call synchronize_rcu() before freeing old structures, which is simply more efficient and reduces code churn where remote_addr/transport are already being updated inside RCU read-side sections. Only virtio has been tested, but updates were necessary to the VMCI and hyperv code. Unfortunately, the author does not have access to VMCI/hyperv systems, so those changes are untested. Perf Tests (results from patch v2): vCPUs: 16, Threads: 16, Payload: 4KB, Test Runs: 5, Type: SOCK_DGRAM; Before: 245.2 MB/s, After: 509.2 MB/s (+107%). Notably, on the same test system, vsock dgram even outperforms multi-threaded UDP over virtio-net with vhost and MQ support enabled. Throughput metrics for single-threaded SOCK_DGRAM and single/multi-threaded SOCK_STREAM showed no statistically significant throughput changes (lowest p-value reaching 0.27), with the mean difference ranging from -5% to +1%. 
Signed-off-by: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx> --- drivers/vhost/vsock.c | 12 +- include/linux/virtio_vsock.h | 3 +- include/net/af_vsock.h | 39 ++- net/vmw_vsock/af_vsock.c | 451 +++++++++++++++++++++++++------- net/vmw_vsock/diag.c | 10 +- net/vmw_vsock/hyperv_transport.c | 27 +- net/vmw_vsock/virtio_transport_common.c | 32 ++- net/vmw_vsock/vmci_transport.c | 84 ++++-- net/vmw_vsock/vsock_bpf.c | 10 +- 9 files changed, 518 insertions(+), 150 deletions(-) diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index 159c1a22c1a8..b027a780d333 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -297,13 +297,17 @@ static int vhost_transport_cancel_pkt(struct vsock_sock *vsk) { struct vhost_vsock *vsock; + unsigned int cid; int cnt = 0; int ret = -ENODEV; rcu_read_lock(); + ret = vsock_remote_addr_cid(vsk, &cid); + if (ret < 0) + goto out; /* Find the vhost_vsock according to guest context id */ - vsock = vhost_vsock_get(vsk->remote_addr.svm_cid); + vsock = vhost_vsock_get(cid); if (!vsock) goto out; @@ -706,6 +710,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock) static void vhost_vsock_reset_orphans(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); + unsigned int cid; + + if (vsock_remote_addr_cid(vsk, &cid) < 0) + return; /* vmci_transport.c doesn't take sk_lock here either. At least we're * under vsock_table_lock so the sock cannot disappear while we're @@ -713,7 +721,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk) */ /* If the peer is still valid, no need to reset connection */ - if (vhost_vsock_get(vsk->remote_addr.svm_cid)) + if (vhost_vsock_get(cid)) return; /* If the close timeout is pending, let it expire. 
This avoids races diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h index 237ca87a2ecd..97656e83606f 100644 --- a/include/linux/virtio_vsock.h +++ b/include/linux/virtio_vsock.h @@ -231,7 +231,8 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg, size_t len); int -virtio_transport_dgram_enqueue(struct vsock_sock *vsk, +virtio_transport_dgram_enqueue(const struct vsock_transport *transport, + struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, size_t len); diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index c115e655b4f5..84f2a9700ebd 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock; #define vsock_sk(__sk) ((struct vsock_sock *)__sk) #define sk_vsock(__vsk) (&(__vsk)->sk) +struct vsock_remote_info { + struct sockaddr_vm addr; + struct rcu_head rcu; + const struct vsock_transport *transport; +}; + struct vsock_sock { /* sk must be the first member. */ struct sock sk; - const struct vsock_transport *transport; struct sockaddr_vm local_addr; - struct sockaddr_vm remote_addr; + struct vsock_remote_info * __rcu remote_info; /* Links for the global tables of bound and connected sockets. */ struct list_head bound_table; struct list_head connected_table; @@ -120,8 +125,8 @@ struct vsock_transport { /* DGRAM. 
*/ int (*dgram_bind)(struct vsock_sock *, struct sockaddr_vm *); - int (*dgram_enqueue)(struct vsock_sock *, struct sockaddr_vm *, - struct msghdr *, size_t len); + int (*dgram_enqueue)(const struct vsock_transport *, struct vsock_sock *, + struct sockaddr_vm *, struct msghdr *, size_t len); bool (*dgram_allow)(u32 cid, u32 port); int (*dgram_get_cid)(struct sk_buff *skb, unsigned int *cid); int (*dgram_get_port)(struct sk_buff *skb, unsigned int *port); @@ -196,6 +201,17 @@ void vsock_core_unregister(const struct vsock_transport *t); /* The transport may downcast this to access transport-specific functions */ const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk); +static inline struct vsock_remote_info * +vsock_core_get_remote_info(struct vsock_sock *vsk) +{ + + /* vsk->remote_info may be accessed if the rcu read lock is held OR the + * socket lock is held + */ + return rcu_dereference_check(vsk->remote_info, + lockdep_sock_is_held(sk_vsock(vsk))); +} + /**** UTILS ****/ /* vsock_table_lock must be held */ @@ -214,7 +230,7 @@ void vsock_release_pending(struct sock *pending); void vsock_add_pending(struct sock *listener, struct sock *pending); void vsock_remove_pending(struct sock *listener, struct sock *pending); void vsock_enqueue_accept(struct sock *listener, struct sock *connected); -void vsock_insert_connected(struct vsock_sock *vsk); +int vsock_insert_connected(struct vsock_sock *vsk); void vsock_remove_bound(struct vsock_sock *vsk); void vsock_remove_connected(struct vsock_sock *vsk); struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); @@ -223,7 +239,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, void vsock_remove_sock(struct vsock_sock *vsk); void vsock_for_each_connected_socket(struct vsock_transport *transport, void (*fn)(struct sock *sk)); -int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk); +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk, + 
struct sockaddr_vm *remote_addr); bool vsock_find_cid(unsigned int cid); struct sock *vsock_find_bound_dgram_socket(struct sockaddr_vm *addr); @@ -253,4 +270,14 @@ static inline void __init vsock_bpf_build_proto(void) {} #endif +/* RCU-protected remote addr helpers */ +int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid); +int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port); +int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid, + unsigned int *port); +int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest); +bool vsock_remote_addr_bound(struct vsock_sock *vsk); +bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other); +int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port); + #endif /* __AF_VSOCK_H__ */ diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index e8c70069d77d..0520228d2a68 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -114,6 +114,8 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); static void vsock_sk_destruct(struct sock *sk); static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); +static bool vsock_use_local_transport(unsigned int remote_cid); +static bool sock_type_connectible(u16 type); /* Protocol family. 
*/ struct proto vsock_proto = { @@ -145,6 +147,147 @@ static const struct vsock_transport *transport_local; static DEFINE_MUTEX(vsock_register_mutex); /**** UTILS ****/ +bool vsock_remote_addr_bound(struct vsock_sock *vsk) +{ + struct vsock_remote_info *remote_info; + bool ret; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return false; + } + + ret = vsock_addr_bound(&remote_info->addr); + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_bound); + +int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest) +{ + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return -EINVAL; + } + memcpy(dest, &remote_info->addr, sizeof(*dest)); + rcu_read_unlock(); + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_copy); + +int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid) +{ + return vsock_remote_addr_cid_port(vsk, cid, NULL); +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_cid); + +int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port) +{ + return vsock_remote_addr_cid_port(vsk, NULL, port); +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_port); + +int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid, + unsigned int *port) +{ + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return -EINVAL; + } + + if (cid) + *cid = remote_info->addr.svm_cid; + if (port) + *port = remote_info->addr.svm_port; + + rcu_read_unlock(); + return 0; +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port); + +/* The socket lock must be held by the caller */ +int vsock_set_remote_info(struct vsock_sock *vsk, + const struct vsock_transport *transport, + struct sockaddr_vm *addr) +{ + struct vsock_remote_info *old, *new; + + if (addr || transport) { + 
new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOMEM; + + if (addr) + memcpy(&new->addr, addr, sizeof(new->addr)); + + if (transport) + new->transport = transport; + } else { + new = NULL; + } + + old = rcu_replace_pointer(vsk->remote_info, new, lockdep_sock_is_held(sk_vsock(vsk))); + kfree_rcu(old, rcu); + + return 0; +} + +static const struct vsock_transport * +vsock_connectible_lookup_transport(unsigned int cid, __u8 flags) +{ + const struct vsock_transport *transport; + + if (vsock_use_local_transport(cid)) + transport = transport_local; + else if (cid <= VMADDR_CID_HOST || !transport_h2g || + (flags & VMADDR_FLAG_TO_HOST)) + transport = transport_g2h; + else + transport = transport_h2g; + + return transport; +} + +static const struct vsock_transport * +vsock_dgram_lookup_transport(unsigned int cid, __u8 flags) +{ + if (transport_dgram) + return transport_dgram; + + return vsock_connectible_lookup_transport(cid, flags); +} + +bool vsock_remote_addr_equals(struct vsock_sock *vsk, + struct sockaddr_vm *other) +{ + struct vsock_remote_info *remote_info; + bool equals; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return false; + } + + equals = vsock_addr_equals_addr(&remote_info->addr, other); + rcu_read_unlock(); + + return equals; +} +EXPORT_SYMBOL_GPL(vsock_remote_addr_equals); /* Each bound VSocket is stored in the bind hash table and each connected * VSocket is stored in the connected hash table. 
@@ -284,10 +427,16 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, list_for_each_entry(vsk, vsock_connected_sockets(src, dst), connected_table) { - if (vsock_addr_equals_addr(src, &vsk->remote_addr) && + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (vsock_addr_equals_addr(src, &remote_info->addr) && dst->svm_port == vsk->local_addr.svm_port) { + rcu_read_unlock(); return sk_vsock(vsk); } + rcu_read_unlock(); } return NULL; @@ -300,17 +449,36 @@ static void vsock_insert_unbound(struct vsock_sock *vsk) spin_unlock_bh(&vsock_table_lock); } -void vsock_insert_connected(struct vsock_sock *vsk) +int vsock_insert_connected(struct vsock_sock *vsk) { - struct list_head *list = vsock_connected_sockets( - &vsk->remote_addr, &vsk->local_addr); + struct list_head *list; + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return -EINVAL; + } + list = vsock_connected_sockets(&remote_info->addr, &vsk->local_addr); + rcu_read_unlock(); spin_lock_bh(&vsock_table_lock); __vsock_insert_connected(list, vsk); spin_unlock_bh(&vsock_table_lock); + + return 0; } EXPORT_SYMBOL_GPL(vsock_insert_connected); +void vsock_remove_dgram_bound(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_dgram_table_lock); + if (__vsock_in_bound_table(vsk)) + __vsock_remove_bound(vsk); + spin_unlock_bh(&vsock_dgram_table_lock); +} + void vsock_remove_bound(struct vsock_sock *vsk) { spin_lock_bh(&vsock_table_lock); @@ -362,7 +530,10 @@ EXPORT_SYMBOL_GPL(vsock_find_connected_socket); void vsock_remove_sock(struct vsock_sock *vsk) { - vsock_remove_bound(vsk); + if (sock_type_connectible(sk_vsock(vsk)->sk_type)) + vsock_remove_bound(vsk); + else + vsock_remove_dgram_bound(vsk); vsock_remove_connected(vsk); } EXPORT_SYMBOL_GPL(vsock_remove_sock); @@ -378,7 +549,7 @@ void vsock_for_each_connected_socket(struct 
vsock_transport *transport, struct vsock_sock *vsk; list_for_each_entry(vsk, &vsock_connected_table[i], connected_table) { - if (vsk->transport != transport) + if (vsock_core_get_transport(vsk) != transport) continue; fn(sk_vsock(vsk)); @@ -444,59 +615,39 @@ static bool vsock_use_local_transport(unsigned int remote_cid) static void vsock_deassign_transport(struct vsock_sock *vsk) { - if (!vsk->transport) - return; - - vsk->transport->destruct(vsk); - module_put(vsk->transport->module); - vsk->transport = NULL; -} - -static const struct vsock_transport * -vsock_connectible_lookup_transport(unsigned int cid, __u8 flags) -{ - const struct vsock_transport *transport; + struct vsock_remote_info *remote_info; - if (vsock_use_local_transport(cid)) - transport = transport_local; - else if (cid <= VMADDR_CID_HOST || !transport_h2g || - (flags & VMADDR_FLAG_TO_HOST)) - transport = transport_g2h; - else - transport = transport_h2g; - - return transport; -} - -static const struct vsock_transport * -vsock_dgram_lookup_transport(unsigned int cid, __u8 flags) -{ - if (transport_dgram) - return transport_dgram; + remote_info = rcu_replace_pointer(vsk->remote_info, NULL, + lockdep_sock_is_held(sk_vsock(vsk))); + if (!remote_info) + return; - return vsock_connectible_lookup_transport(cid, flags); + remote_info->transport->destruct(vsk); + module_put(remote_info->transport->module); + kfree_rcu(remote_info, rcu); } /* Assign a transport to a socket and call the .init transport callback. * - * Note: for connection oriented socket this must be called when vsk->remote_addr - * is set (e.g. during the connect() or when a connection request on a listener - * socket is received). 
- * The vsk->remote_addr is used to decide which transport to use: + * The remote_addr is used to decide which transport to use: * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if * g2h is not loaded, will use local transport; * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; * - remote CID > VMADDR_CID_HOST will use host->guest transport; */ -int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk, + struct sockaddr_vm *remote_addr) { const struct vsock_transport *new_transport; + struct vsock_remote_info *old_info; struct sock *sk = sk_vsock(vsk); - unsigned int remote_cid = vsk->remote_addr.svm_cid; + unsigned int remote_cid; __u8 remote_flags; int ret; + remote_cid = remote_addr->svm_cid; + /* If the packet is coming with the source and destination CIDs higher * than VMADDR_CID_HOST, then a vsock channel where all the packets are * forwarded to the host should be established. Then the host will @@ -506,10 +657,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) * the connect path the flag can be set by the user space application. 
*/ if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST && - vsk->remote_addr.svm_cid > VMADDR_CID_HOST) - vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + remote_cid > VMADDR_CID_HOST) + remote_addr->svm_flags |= VMADDR_FLAG_TO_HOST; - remote_flags = vsk->remote_addr.svm_flags; + remote_flags = remote_addr->svm_flags; switch (sk->sk_type) { case SOCK_DGRAM: @@ -525,8 +676,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) return -ESOCKTNOSUPPORT; } - if (vsk->transport) { - if (vsk->transport == new_transport) + old_info = vsock_core_get_remote_info(vsk); + if (old_info && old_info->transport) { + if (old_info->transport == new_transport) return 0; /* transport->release() must be called with sock lock acquired. @@ -535,7 +687,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) * function is called on a new socket which is not assigned to * any transport. */ - vsk->transport->release(vsk); + old_info->transport->release(vsk); vsock_deassign_transport(vsk); } @@ -553,13 +705,18 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) } } - ret = new_transport->init(vsk, psk); + ret = vsock_set_remote_info(vsk, new_transport, remote_addr); if (ret) { module_put(new_transport->module); return ret; } - vsk->transport = new_transport; + ret = new_transport->init(vsk, psk); + if (ret) { + vsock_set_remote_info(vsk, NULL, NULL); + module_put(new_transport->module); + return ret; + } return 0; } @@ -616,12 +773,14 @@ static bool vsock_is_pending(struct sock *sk) static int vsock_send_shutdown(struct sock *sk, int mode) { + const struct vsock_transport *transport; struct vsock_sock *vsk = vsock_sk(sk); - if (!vsk->transport) + transport = vsock_core_get_transport(vsk); + if (!transport) return -ENODEV; - return vsk->transport->shutdown(vsk, mode); + return transport->shutdown(vsk, mode); } static void vsock_pending_work(struct work_struct *work) @@ -757,7 +916,10 @@ 
EXPORT_SYMBOL(vsock_bind_stream); static int vsock_bind_dgram(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - if (!vsk->transport || !vsk->transport->dgram_bind) { + const struct vsock_transport *transport; + + transport = vsock_core_get_transport(vsk); + if (!transport || !transport->dgram_bind) { int retval; spin_lock_bh(&vsock_dgram_table_lock); retval = vsock_bind_common(vsk, addr, vsock_dgram_bind_table, @@ -767,7 +929,7 @@ static int vsock_bind_dgram(struct vsock_sock *vsk, return retval; } - return vsk->transport->dgram_bind(vsk, addr); + return transport->dgram_bind(vsk, addr); } static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) @@ -816,6 +978,7 @@ static struct sock *__vsock_create(struct net *net, unsigned short type, int kern) { + struct vsock_remote_info *remote_info; struct sock *sk; struct vsock_sock *psk; struct vsock_sock *vsk; @@ -835,7 +998,14 @@ static struct sock *__vsock_create(struct net *net, vsk = vsock_sk(sk); vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + + remote_info = kmalloc(sizeof(*remote_info), GFP_KERNEL); + if (!remote_info) { + sk_free(sk); + return NULL; + } + vsock_addr_init(&remote_info->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + rcu_assign_pointer(vsk->remote_info, remote_info); sk->sk_destruct = vsock_sk_destruct; sk->sk_backlog_rcv = vsock_queue_rcv_skb; @@ -882,6 +1052,7 @@ static bool sock_type_connectible(u16 type) static void __vsock_release(struct sock *sk, int level) { if (sk) { + const struct vsock_transport *transport; struct sock *pending; struct vsock_sock *vsk; @@ -895,8 +1066,9 @@ static void __vsock_release(struct sock *sk, int level) */ lock_sock_nested(sk, level); - if (vsk->transport) - vsk->transport->release(vsk); + transport = vsock_core_get_transport(vsk); + if (transport) + transport->release(vsk); else if (sock_type_connectible(sk->sk_type)) vsock_remove_sock(vsk); @@ -926,8 +1098,6 @@ 
static void vsock_sk_destruct(struct sock *sk) * possibly register the address family with the kernel. */ vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); - put_cred(vsk->owner); } @@ -951,16 +1121,22 @@ EXPORT_SYMBOL_GPL(vsock_create_connected); s64 vsock_stream_has_data(struct vsock_sock *vsk) { - return vsk->transport->stream_has_data(vsk); + const struct vsock_transport *transport; + + transport = vsock_core_get_transport(vsk); + + return transport->stream_has_data(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_data); s64 vsock_connectible_has_data(struct vsock_sock *vsk) { + const struct vsock_transport *transport; struct sock *sk = sk_vsock(vsk); + transport = vsock_core_get_transport(vsk); if (sk->sk_type == SOCK_SEQPACKET) - return vsk->transport->seqpacket_has_data(vsk); + return transport->seqpacket_has_data(vsk); else return vsock_stream_has_data(vsk); } @@ -968,7 +1144,10 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data); s64 vsock_stream_has_space(struct vsock_sock *vsk) { - return vsk->transport->stream_has_space(vsk); + const struct vsock_transport *transport; + + transport = vsock_core_get_transport(vsk); + return transport->stream_has_space(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_space); @@ -1017,6 +1196,7 @@ static int vsock_getname(struct socket *sock, struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *vm_addr; + struct vsock_remote_info *rcu_ptr; sk = sock->sk; vsk = vsock_sk(sk); @@ -1025,11 +1205,17 @@ static int vsock_getname(struct socket *sock, lock_sock(sk); if (peer) { + rcu_read_lock(); if (sock->state != SS_CONNECTED) { err = -ENOTCONN; goto out; } - vm_addr = &vsk->remote_addr; + rcu_ptr = vsock_core_get_remote_info(vsk); + if (!rcu_ptr) { + err = -EINVAL; + goto out; + } + vm_addr = &rcu_ptr->addr; } else { vm_addr = &vsk->local_addr; } @@ -1049,6 +1235,8 @@ static int vsock_getname(struct socket *sock, err = sizeof(*vm_addr); out: + if 
(peer) + rcu_read_unlock(); release_sock(sk); return err; } @@ -1153,7 +1341,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, lock_sock(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); /* Listening sockets that have connections in their accept * queue can be read. @@ -1224,9 +1412,11 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor) { + const struct vsock_transport *transport; struct vsock_sock *vsk = vsock_sk(sk); - return vsk->transport->read_skb(vsk, read_actor); + transport = vsock_core_get_transport(vsk); + return transport->read_skb(vsk, read_actor); } static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, @@ -1235,7 +1425,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, int err; struct sock *sk; struct vsock_sock *vsk; - struct sockaddr_vm *remote_addr; + struct sockaddr_vm stack_addr, *remote_addr; const struct vsock_transport *transport; if (msg->msg_flags & MSG_OOB) @@ -1246,7 +1436,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, sk = sock->sk; vsk = vsock_sk(sk); - lock_sock(sk); + /* If auto-binding is required, acquire the slock to avoid potential + * race conditions. Otherwise, do not acquire the lock. + * + * We know that the first check of local_addr is racy (indicated by + * data_race()). By acquiring the lock and then subsequently checking + * again if local_addr is bound (inside vsock_auto_bind()), we can + * ensure there are no real data races. + * + * This technique is borrowed by inet_send_prepare(). + */ + if (data_race(!vsock_addr_bound(&vsk->local_addr))) { + lock_sock(sk); + err = vsock_auto_bind(vsk); + release_sock(sk); + if (err) + return err; + } /* If the provided message contains an address, use that. Otherwise * fall back on the socket's remote handle (if it has been connected). 
@@ -1256,6 +1462,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, &remote_addr) == 0) { transport = vsock_dgram_lookup_transport(remote_addr->svm_cid, remote_addr->svm_flags); + if (!transport) { err = -EINVAL; goto out; @@ -1286,18 +1493,39 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - err = transport->dgram_enqueue(vsk, remote_addr, msg, len); + err = transport->dgram_enqueue(transport, vsk, remote_addr, msg, len); module_put(transport->module); } else if (sock->state == SS_CONNECTED) { - remote_addr = &vsk->remote_addr; - transport = vsk->transport; + struct vsock_remote_info *remote_info; + const struct vsock_transport *transport; - err = vsock_auto_bind(vsk); - if (err) + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + err = -EINVAL; + rcu_read_unlock(); goto out; + } - if (remote_addr->svm_cid == VMADDR_CID_ANY) + transport = remote_info->transport; + memcpy(&stack_addr, &remote_info->addr, sizeof(stack_addr)); + rcu_read_unlock(); + + remote_addr = &stack_addr; + + if (remote_addr->svm_cid == VMADDR_CID_ANY) { remote_addr->svm_cid = transport->get_local_cid(); + lock_sock(sk_vsock(vsk)); + /* Even though the CID has changed, We do not have to + * look up the transport again because the local CID + * will never resolve to a different transport. + */ + err = vsock_set_remote_info(vsk, transport, remote_addr); + release_sock(sk_vsock(vsk)); + + if (err) + goto out; + } /* XXX Should connect() or this function ensure remote_addr is * bound? 
@@ -1313,14 +1541,13 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - err = transport->dgram_enqueue(vsk, remote_addr, msg, len); + err = transport->dgram_enqueue(transport, vsk, &stack_addr, msg, len); } else { err = -EINVAL; goto out; } out: - release_sock(sk); return err; } @@ -1331,18 +1558,22 @@ static int vsock_dgram_connect(struct socket *sock, struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; sk = sock->sk; vsk = vsock_sk(sk); err = vsock_addr_cast(addr, addr_len, &remote_addr); if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) { + struct sockaddr_vm addr_any; + lock_sock(sk); - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, - VMADDR_PORT_ANY); + vsock_addr_init(&addr_any, VMADDR_CID_ANY, VMADDR_PORT_ANY); + err = vsock_set_remote_info(vsk, vsock_core_get_transport(vsk), + &addr_any); sock->state = SS_UNCONNECTED; release_sock(sk); - return 0; + return err; } else if (err != 0) return -EINVAL; @@ -1352,14 +1583,13 @@ static int vsock_dgram_connect(struct socket *sock, if (err) goto out; - memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); - - err = vsock_assign_transport(vsk, NULL); + err = vsock_assign_transport(vsk, NULL, remote_addr); if (err) goto out; - if (!vsk->transport->dgram_allow(remote_addr->svm_cid, - remote_addr->svm_port)) { + transport = vsock_core_get_transport(vsk); + if (!transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { err = -EINVAL; goto out; } @@ -1406,7 +1636,9 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, if (flags & MSG_OOB || flags & MSG_ERRQUEUE) return -EOPNOTSUPP; - transport = vsk->transport; + rcu_read_lock(); + transport = vsock_core_get_transport(vsk); + rcu_read_unlock(); /* Retrieve the head sk_buff from the socket's receive queue.
*/ err = 0; @@ -1474,7 +1706,7 @@ static const struct proto_ops vsock_dgram_ops = { static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) { - const struct vsock_transport *transport = vsk->transport; + const struct vsock_transport *transport = vsock_core_get_transport(vsk); if (!transport || !transport->cancel_pkt) return -EOPNOTSUPP; @@ -1511,6 +1743,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr, struct sock *sk; struct vsock_sock *vsk; const struct vsock_transport *transport; + struct vsock_remote_info *remote_info; struct sockaddr_vm *remote_addr; long timeout; DEFINE_WAIT(wait); @@ -1548,14 +1781,20 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr, } /* Set the remote address that we are connecting to. */ - memcpy(&vsk->remote_addr, remote_addr, - sizeof(vsk->remote_addr)); - - err = vsock_assign_transport(vsk, NULL); + err = vsock_assign_transport(vsk, NULL, remote_addr); if (err) goto out; - transport = vsk->transport; + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + err = -EINVAL; + rcu_read_unlock(); + goto out; + } + + transport = remote_info->transport; + rcu_read_unlock(); /* The hypervisor and well-known contexts do not have socket * endpoints. @@ -1819,7 +2058,7 @@ static int vsock_connectible_setsockopt(struct socket *sock, lock_sock(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: @@ -1957,7 +2196,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg, lock_sock(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); /* Callers should not provide a destination with connection oriented * sockets.
@@ -1980,7 +2219,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - if (!vsock_addr_bound(&vsk->remote_addr)) { + if (!vsock_remote_addr_bound(vsk)) { err = -EDESTADDRREQ; goto out; } @@ -2101,7 +2340,7 @@ static int vsock_connectible_wait_data(struct sock *sk, vsk = vsock_sk(sk); err = 0; - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); while (1) { prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE); @@ -2169,7 +2408,7 @@ static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg, DEFINE_WAIT(wait); vsk = vsock_sk(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); /* We must not copy less than target bytes into the user's buffer * before returning successfully, so we wait for the consume queue to @@ -2245,7 +2484,7 @@ static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg, DEFINE_WAIT(wait); vsk = vsock_sk(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); @@ -2302,7 +2541,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, lock_sock(sk); - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); if (!transport || sk->sk_state != TCP_ESTABLISHED) { /* Recvmsg is supposed to return 0 if a peer performs an @@ -2369,7 +2608,7 @@ static int vsock_set_rcvlowat(struct sock *sk, int val) if (val > vsk->buffer_size) return -EINVAL; - transport = vsk->transport; + transport = vsock_core_get_transport(vsk); if (transport && transport->set_rcvlowat) return transport->set_rcvlowat(vsk, val); @@ -2459,7 +2698,10 @@ static int vsock_create(struct net *net, struct socket *sock, vsk = vsock_sk(sk); if (sock->type == SOCK_DGRAM) { - ret = vsock_assign_transport(vsk, NULL); + struct sockaddr_vm remote_addr; + + vsock_addr_init(&remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + ret = vsock_assign_transport(vsk, NULL,
&remote_addr); if (ret < 0) { sock_put(sk); return ret; @@ -2581,7 +2823,18 @@ static void __exit vsock_exit(void) const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) { - return vsk->transport; + const struct vsock_transport *transport; + struct vsock_remote_info *remote_info; + + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + rcu_read_unlock(); + return NULL; + } + transport = remote_info->transport; + rcu_read_unlock(); + return transport; } EXPORT_SYMBOL_GPL(vsock_core_get_transport); diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c index a2823b1c5e28..f843bae86b32 100644 --- a/net/vmw_vsock/diag.c +++ b/net/vmw_vsock/diag.c @@ -15,8 +15,14 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, u32 portid, u32 seq, u32 flags) { struct vsock_sock *vsk = vsock_sk(sk); + struct sockaddr_vm remote_addr; struct vsock_diag_msg *rep; struct nlmsghdr *nlh; + int err; + + err = vsock_remote_addr_copy(vsk, &remote_addr); + if (err < 0) + return err; nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), flags); @@ -36,8 +42,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, rep->vdiag_shutdown = sk->sk_shutdown; rep->vdiag_src_cid = vsk->local_addr.svm_cid; rep->vdiag_src_port = vsk->local_addr.svm_port; - rep->vdiag_dst_cid = vsk->remote_addr.svm_cid; - rep->vdiag_dst_port = vsk->remote_addr.svm_port; + rep->vdiag_dst_cid = remote_addr.svm_cid; + rep->vdiag_dst_port = remote_addr.svm_port; rep->vdiag_ino = sock_i_ino(sk); sock_diag_save_cookie(sk, rep->vdiag_cookie); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index c00bc5da769a..84e8c64b3365 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -323,6 +323,8 @@ static void hvs_open_connection(struct vmbus_channel *chan) goto out; if (conn_from_host) { + struct sockaddr_vm remote_addr; + if (sk->sk_ack_backlog >=
sk->sk_max_ack_backlog) goto out; @@ -336,10 +338,9 @@ static void hvs_open_connection(struct vmbus_channel *chan) hvs_addr_init(&vnew->local_addr, if_type); /* Remote peer is always the host */ - vsock_addr_init(&vnew->remote_addr, - VMADDR_CID_HOST, VMADDR_PORT_ANY); - vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance); - ret = vsock_assign_transport(vnew, vsock_sk(sk)); + vsock_addr_init(&remote_addr, VMADDR_CID_HOST, get_port_by_srv_id(if_instance)); + + ret = vsock_assign_transport(vnew, vsock_sk(sk), &remote_addr); /* Transport assigned (looking at remote_addr) must be the * same where we received the request. */ @@ -459,13 +460,18 @@ static int hvs_connect(struct vsock_sock *vsk) { union hvs_service_id vm, host; struct hvsock *h = vsk->trans; + int err; vm.srv_id = srv_id_template; vm.svm_port = vsk->local_addr.svm_port; h->vm_srv_id = vm.srv_id; host.srv_id = srv_id_template; - host.svm_port = vsk->remote_addr.svm_port; + + err = vsock_remote_addr_port(vsk, &host.svm_port); + if (err < 0) + return err; + h->host_srv_id = host.srv_id; return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id); @@ -566,7 +572,8 @@ static int hvs_dgram_get_length(struct sk_buff *skb, size_t *len) return -EOPNOTSUPP; } -static int hvs_dgram_enqueue(struct vsock_sock *vsk, +static int hvs_dgram_enqueue(const struct vsock_transport *transport, + struct vsock_sock *vsk, struct sockaddr_vm *remote, struct msghdr *msg, size_t dgram_len) { @@ -866,7 +873,13 @@ static struct vsock_transport hvs_transport = { static bool hvs_check_transport(struct vsock_sock *vsk) { - return vsk->transport == &hvs_transport; + bool ret; + + rcu_read_lock(); + ret = vsock_core_get_transport(vsk) == &hvs_transport; + rcu_read_unlock(); + + return ret; } static int hvs_probe(struct hv_device *hdev, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index ab4af21c4f3f..09d35c488902 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++
b/net/vmw_vsock/virtio_transport_common.c @@ -258,8 +258,9 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, src_cid = t_ops->transport.get_local_cid(); src_port = vsk->local_addr.svm_port; if (!info->remote_cid) { - dst_cid = vsk->remote_addr.svm_cid; - dst_port = vsk->remote_addr.svm_port; + ret = vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port); + if (ret < 0) + return ret; } else { dst_cid = info->remote_cid; dst_port = info->remote_port; @@ -877,12 +878,14 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) EXPORT_SYMBOL_GPL(virtio_transport_shutdown); int -virtio_transport_dgram_enqueue(struct vsock_sock *vsk, +virtio_transport_dgram_enqueue(const struct vsock_transport *transport, + struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, size_t dgram_len) { - const struct virtio_transport *t_ops; + const struct virtio_transport *t_ops = + (const struct virtio_transport *)transport; struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RW, .msg = msg, @@ -896,7 +899,6 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk, if (dgram_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) return -EMSGSIZE; - t_ops = virtio_transport_get_ops(vsk); src_cid = t_ops->transport.get_local_cid(); src_port = vsk->local_addr.svm_port; @@ -1120,7 +1122,12 @@ virtio_transport_recv_connecting(struct sock *sk, case VIRTIO_VSOCK_OP_RESPONSE: sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; - vsock_insert_connected(vsk); + err = vsock_insert_connected(vsk); + if (err) { + /* The destroy path reports skerr to the socket, so set it. */ + skerr = -err; + goto destroy; + } sk->sk_state_change(sk); break; case VIRTIO_VSOCK_OP_INVALID: @@ -1326,6 +1333,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vchild; + struct sockaddr_vm child_remote; struct sock *child; int ret; @@ -1354,14 +1362,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, vchild
= vsock_sk(child); vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid), le32_to_cpu(hdr->dst_port)); - vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid), + vsock_addr_init(&child_remote, le64_to_cpu(hdr->src_cid), le32_to_cpu(hdr->src_port)); - - ret = vsock_assign_transport(vchild, vsk); + ret = vsock_assign_transport(vchild, vsk, &child_remote); /* Transport assigned (looking at remote_addr) must be the same * where we received the request. */ - if (ret || vchild->transport != &t->transport) { + if (ret || vsock_core_get_transport(vchild) != &t->transport) { release_sock(child); virtio_transport_reset_no_sock(t, skb); sock_put(child); @@ -1371,7 +1378,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, if (virtio_transport_space_update(child, skb)) child->sk_write_space(child); - vsock_insert_connected(vchild); + ret = vsock_insert_connected(vchild); + if (ret) { + release_sock(child); + virtio_transport_reset_no_sock(t, skb); + sock_put(child); + return ret; + } vsock_enqueue_accept(sk, child); virtio_transport_send_response(vchild, skb); diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index b6a51afb74b8..b9ba6209e8fc 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -283,18 +283,25 @@ vmci_transport_send_control_pkt(struct sock *sk, u16 proto, struct vmci_handle handle) { + struct sockaddr_vm remote_addr; struct vsock_sock *vsk; + int err; vsk = vsock_sk(sk); if (!vsock_addr_bound(&vsk->local_addr)) return -EINVAL; - if (!vsock_addr_bound(&vsk->remote_addr)) + /* Snapshot the remote address once; check and send that same copy. */ + err = vsock_remote_addr_copy(vsk, &remote_addr); + if (err < 0) + return err; + + if (!vsock_addr_bound(&remote_addr)) return -EINVAL; return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, - &vsk->remote_addr, + &remote_addr, type, size, mode, wait, proto, handle); } @@ -317,6 +324,7 @@ static int vmci_transport_send_reset(struct sock *sk, struct
sockaddr_vm *dst_ptr; struct sockaddr_vm dst; struct vsock_sock *vsk; + int err; if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) return 0; @@ -326,13 +334,16 @@ static int vmci_transport_send_reset(struct sock *sk, if (!vsock_addr_bound(&vsk->local_addr)) return -EINVAL; - if (vsock_addr_bound(&vsk->remote_addr)) { - dst_ptr = &vsk->remote_addr; + if (vsock_remote_addr_bound(vsk)) { + err = vsock_remote_addr_copy(vsk, &dst); + if (err < 0) + return err; } else { vsock_addr_init(&dst, pkt->dg.src.context, pkt->src_port); - dst_ptr = &dst; } + dst_ptr = &dst; + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr, VMCI_TRANSPORT_PACKET_TYPE_RST, 0, 0, NULL, VSOCK_PROTO_INVALID, @@ -490,7 +501,7 @@ static struct sock *vmci_transport_get_pending( list_for_each_entry(vpending, &vlistener->pending_links, pending_links) { - if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && + if (vsock_remote_addr_equals(vpending, &src) && pkt->dst_port == vpending->local_addr.svm_port) { pending = sk_vsock(vpending); sock_hold(pending); @@ -940,6 +951,7 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work) static int vmci_transport_recv_listen(struct sock *sk, struct vmci_transport_packet *pkt) { + struct sockaddr_vm remote_addr; struct sock *pending; struct vsock_sock *vpending; int err; @@ -1015,10 +1027,10 @@ static int vmci_transport_recv_listen(struct sock *sk, vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context, pkt->dst_port); - vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, - pkt->src_port); - err = vsock_assign_transport(vpending, vsock_sk(sk)); + vsock_addr_init(&remote_addr, pkt->dg.src.context, pkt->src_port); + + err = vsock_assign_transport(vpending, vsock_sk(sk), &remote_addr); /* Transport assigned (looking at remote_addr) must be the same * where we received the request.
*/ @@ -1133,6 +1145,7 @@ vmci_transport_recv_connecting_server(struct sock *listener, { struct vsock_sock *vpending; struct vmci_handle handle; + unsigned int vpending_remote_cid; struct vmci_qp *qpair; bool is_local; u32 flags; @@ -1189,8 +1202,13 @@ vmci_transport_recv_connecting_server(struct sock *listener, /* vpending->local_addr always has a context id so we do not need to * worry about VMADDR_CID_ANY in this case. */ - is_local = - vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid; + err = vsock_remote_addr_cid(vpending, &vpending_remote_cid); + if (err < 0) { + skerr = EPROTO; + goto destroy; + } + + is_local = vpending_remote_cid == vpending->local_addr.svm_cid; flags = VMCI_QPFLAG_ATTACH_ONLY; flags |= is_local ? VMCI_QPFLAG_LOCAL : 0; @@ -1203,7 +1221,7 @@ vmci_transport_recv_connecting_server(struct sock *listener, flags, vmci_transport_is_trusted( vpending, - vpending->remote_addr.svm_cid)); + vpending_remote_cid)); if (err < 0) { vmci_transport_send_reset(pending, pkt); skerr = -err; @@ -1277,6 +1295,8 @@ static int vmci_transport_recv_connecting_client(struct sock *sk, struct vmci_transport_packet *pkt) { + struct vsock_remote_info *remote_info; + struct sockaddr_vm *remote_addr; struct vsock_sock *vsk; int err; int skerr; @@ -1306,9 +1326,20 @@ vmci_transport_recv_connecting_client(struct sock *sk, break; case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + rcu_read_lock(); + remote_info = vsock_core_get_remote_info(vsk); + if (!remote_info) { + skerr = EPROTO; + err = -EINVAL; + rcu_read_unlock(); + goto destroy; + } + + remote_addr = &remote_info->addr; + if (pkt->u.size == 0 - || pkt->dg.src.context != vsk->remote_addr.svm_cid - || pkt->src_port != vsk->remote_addr.svm_port + || pkt->dg.src.context != remote_addr->svm_cid + || pkt->src_port != remote_addr->svm_port || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle) || vmci_trans(vsk)->qpair || vmci_trans(vsk)->produce_size != 0 @@ -1316,9
+1347,10 @@ vmci_transport_recv_connecting_client(struct sock *sk, || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) { skerr = EPROTO; err = -EINVAL; - + rcu_read_unlock(); goto destroy; } + rcu_read_unlock(); err = vmci_transport_recv_connecting_client_negotiate(sk, pkt); if (err) { @@ -1379,6 +1411,7 @@ static int vmci_transport_recv_connecting_client_negotiate( int err; struct vsock_sock *vsk; struct vmci_handle handle; + unsigned int remote_cid; struct vmci_qp *qpair; u32 detach_sub_id; bool is_local; @@ -1449,19 +1482,23 @@ static int vmci_transport_recv_connecting_client_negotiate( /* Make VMCI select the handle for us. */ handle = VMCI_INVALID_HANDLE; - is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid; + + err = vsock_remote_addr_cid(vsk, &remote_cid); + if (err < 0) + goto destroy; + + is_local = remote_cid == vsk->local_addr.svm_cid; flags = is_local ? VMCI_QPFLAG_LOCAL : 0; err = vmci_transport_queue_pair_alloc(&qpair, &handle, pkt->u.size, pkt->u.size, - vsk->remote_addr.svm_cid, + remote_cid, flags, vmci_transport_is_trusted( vsk, - vsk-> - remote_addr.svm_cid)); + remote_cid)); if (err < 0) goto destroy; @@ -1692,6 +1729,7 @@ static int vmci_transport_dgram_bind(struct vsock_sock *vsk, } static int vmci_transport_dgram_enqueue( + const struct vsock_transport *transport, struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, @@ -2052,7 +2090,13 @@ static struct vsock_transport vmci_transport = { static bool vmci_check_transport(struct vsock_sock *vsk) { - return vsk->transport == &vmci_transport; + bool retval; + + rcu_read_lock(); + retval = vsock_core_get_transport(vsk) == &vmci_transport; + rcu_read_unlock(); + + return retval; } static void vmci_vsock_transport_cb(bool is_host) diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c index a3c97546ab84..4d811c9cdf6e 100644 --- a/net/vmw_vsock/vsock_bpf.c +++ b/net/vmw_vsock/vsock_bpf.c @@ -148,6 +148,7 @@ static void
vsock_bpf_check_needs_rebuild(struct proto *ops) int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) { + const struct vsock_transport *transport; struct vsock_sock *vsk; if (restore) { @@ -157,10 +158,15 @@ int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore } vsk = vsock_sk(sk); - if (!vsk->transport) + + rcu_read_lock(); + transport = vsock_core_get_transport(vsk); + rcu_read_unlock(); + + if (!transport) return -ENODEV; - if (!vsk->transport->read_skb) + if (!transport->read_skb) return -EOPNOTSUPP; vsock_bpf_check_needs_rebuild(psock->sk_proto); -- 2.30.2