Throughput metrics for single-threaded SOCK_DGRAM and
single/multi-threaded SOCK_STREAM showed no statistically significant
throughput changes (lowest p-value reaching 0.27), with the mean
difference ranging from -5% to +1%.
Signed-off-by: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx>
---
drivers/vhost/vsock.c | 12 +-
include/net/af_vsock.h | 19 ++-
net/vmw_vsock/af_vsock.c | 261 ++++++++++++++++++++++++++++----
net/vmw_vsock/diag.c | 10 +-
net/vmw_vsock/hyperv_transport.c | 15 +-
net/vmw_vsock/virtio_transport_common.c | 22 ++-
net/vmw_vsock/vmci_transport.c | 70 ++++++---
7 files changed, 344 insertions(+), 65 deletions(-)
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 028cf079225e..da105cb856ac 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -296,13 +296,17 @@ static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
struct vhost_vsock *vsock;
+ unsigned int cid;
int cnt = 0;
int ret = -ENODEV;
rcu_read_lock();
+	/* Leave ret untouched here: it must still hold its initial -ENODEV
+	 * if the vhost_vsock lookup below fails and jumps to out. Assigning
+	 * the helper's 0 to ret would make that bail-out carry 0 instead.
+	 */
+	if (vsock_remote_addr_cid(vsk, &cid) < 0)
+		goto out;
/* Find the vhost_vsock according to guest context id */
- vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+ vsock = vhost_vsock_get(cid);
if (!vsock)
goto out;
@@ -686,6 +690,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock)
static void vhost_vsock_reset_orphans(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ unsigned int cid;
+
+ if (vsock_remote_addr_cid(vsk, &cid) < 0)
+ return;
/* vmci_transport.c doesn't take sk_lock here either. At least we're
* under vsock_table_lock so the sock cannot disappear while we're
@@ -693,7 +701,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
*/
/* If the peer is still valid, no need to reset connection */
- if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+ if (vhost_vsock_get(cid))
return;
/* If the close timeout is pending, let it expire. This avoids races
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 57af28fede19..c02fd6ad0047 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock;
#define vsock_sk(__sk) ((struct vsock_sock *)__sk)
#define sk_vsock(__vsk) (&(__vsk)->sk)
+struct sockaddr_vm_rcu { /* a sockaddr_vm that can be released with kfree_rcu() */
+ struct sockaddr_vm addr; /* the address payload handed to vsock_addr_*() helpers */
+ struct rcu_head rcu; /* member named in kfree_rcu(old, rcu) by the update helpers */
+};
+
struct vsock_sock {
/* sk must be the first member. */
struct sock sk;
const struct vsock_transport *transport;
struct sockaddr_vm local_addr;
- struct sockaddr_vm remote_addr;
+	/* RCU-managed pointer: read with rcu_dereference(), replaced with
+	 * rcu_replace_pointer() under the socket lock. __rcu goes before the
+	 * '*' so the annotation applies to the pointed-to object.
+	 */
+	struct sockaddr_vm_rcu __rcu *remote_addr;
/* Links for the global tables of bound and connected sockets. */
struct list_head bound_table;
struct list_head connected_table;
@@ -206,7 +211,7 @@ void vsock_release_pending(struct sock *pending);
void vsock_add_pending(struct sock *listener, struct sock *pending);
void vsock_remove_pending(struct sock *listener, struct sock *pending);
void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
-void vsock_insert_connected(struct vsock_sock *vsk);
+int vsock_insert_connected(struct vsock_sock *vsk);
void vsock_remove_bound(struct vsock_sock *vsk);
void vsock_remove_connected(struct vsock_sock *vsk);
struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
@@ -244,4 +249,14 @@ static inline void __init vsock_bpf_build_proto(void)
{}
#endif
+/* RCU-protected remote addr helpers.
+ *
+ * The int getters return 0 on success or -EINVAL when the socket has no
+ * remote address attached; the bool helpers return false in that case.
+ * The update helpers return -ENOMEM when the replacement address cannot be
+ * allocated (vsock_remote_addr_update_cid_port() also returns -EINVAL when
+ * there is no current address) and expect the socket lock to be held.
+ */
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid);
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port);
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port);
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest);
+bool vsock_remote_addr_bound(struct vsock_sock *vsk);
+bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other);
+int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port);
+/* Defined with external linkage in af_vsock.c; declared here so that
+ * definition has a visible prototype.
+ */
+int vsock_remote_addr_update(struct vsock_sock *vsk, struct sockaddr_vm *src);
+
#endif /* __AF_VSOCK_H__ */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 46b3f35e3adc..93b4abbf20b4 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -145,6 +145,139 @@ static const struct vsock_transport *transport_local;
static DEFINE_MUTEX(vsock_register_mutex);
/**** UTILS ****/
+/* Report whether the socket's remote address is bound.
+ *
+ * Returns false when no remote address is attached to @vsk; otherwise
+ * returns what vsock_addr_bound() says about the current address.
+ */
+bool vsock_remote_addr_bound(struct vsock_sock *vsk)
+{
+	struct sockaddr_vm_rcu *raddr;
+	bool bound = false;
+
+	rcu_read_lock();
+	raddr = rcu_dereference(vsk->remote_addr);
+	if (raddr)
+		bound = vsock_addr_bound(&raddr->addr);
+	rcu_read_unlock();
+
+	return bound;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_bound);
+
+/* Snapshot the socket's remote address into @dest.
+ *
+ * Returns 0 after copying, or -EINVAL (leaving @dest untouched) when no
+ * remote address is attached to @vsk.
+ */
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest)
+{
+	struct sockaddr_vm_rcu *raddr;
+	int err = -EINVAL;
+
+	rcu_read_lock();
+	raddr = rcu_dereference(vsk->remote_addr);
+	if (raddr) {
+		memcpy(dest, &raddr->addr, sizeof(*dest));
+		err = 0;
+	}
+	rcu_read_unlock();
+
+	return err;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_copy);
+
+/* Read only the CID of the remote address into @cid.
+ *
+ * Thin wrapper around vsock_remote_addr_cid_port() with the port output
+ * disabled; the return value is passed through unchanged.
+ */
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid)
+{
+	return vsock_remote_addr_cid_port(vsk, cid, NULL);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid);
+
+/* Read only the port of the remote address into @port.
+ *
+ * Thin wrapper around vsock_remote_addr_cid_port() with the CID output
+ * disabled; the return value is passed through unchanged.
+ */
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port)
+{
+	return vsock_remote_addr_cid_port(vsk, NULL, port);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_port);
+
+/* Read the CID and/or port of the socket's remote address.
+ *
+ * Either output pointer may be NULL, in which case that field is skipped.
+ * Returns 0 on success or -EINVAL (outputs untouched) when no remote
+ * address is attached to @vsk.
+ */
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port)
+{
+	struct sockaddr_vm_rcu *raddr;
+	int err = 0;
+
+	rcu_read_lock();
+	raddr = rcu_dereference(vsk->remote_addr);
+	if (!raddr) {
+		err = -EINVAL;
+		goto unlock;
+	}
+
+	if (cid)
+		*cid = raddr->addr.svm_cid;
+	if (port)
+		*port = raddr->addr.svm_port;
+
+unlock:
+	rcu_read_unlock();
+	return err;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port);
+
+/* Swap in a new remote address that keeps every field of the current one
+ * except the CID and port, which are set to @cid and @port.
+ *
+ * The previous address is released with kfree_rcu() so that concurrent RCU
+ * readers can finish with it.
+ *
+ * Returns 0 on success, -ENOMEM if the new address cannot be allocated, or
+ * -EINVAL if the socket has no remote address to start from.
+ *
+ * The socket lock must be held by the caller.
+ */
+int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port)
+{
+	struct sockaddr_vm_rcu *old, *new;
+
+	new = kmalloc(sizeof(*new), GFP_KERNEL);
+	if (!new)
+		return -ENOMEM;
+
+	rcu_read_lock();
+	old = rcu_dereference(vsk->remote_addr);
+	if (!old) {
+		/* Leave the RCU read-side section before bailing out. */
+		rcu_read_unlock();
+		kfree(new);
+		return -EINVAL;
+	}
+	memcpy(&new->addr, &old->addr, sizeof(new->addr));
+	rcu_read_unlock();
+
+	new->addr.svm_cid = cid;
+	new->addr.svm_port = port;
+
+	old = rcu_replace_pointer(vsk->remote_addr, new,
+				  lockdep_sock_is_held(sk_vsock(vsk)));
+	kfree_rcu(old, rcu);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_update_cid_port);
+
+/* Replace the socket's remote address with a copy of @src.
+ *
+ * The previous address is handed to kfree_rcu(). Returns 0 on success or
+ * -ENOMEM if the copy cannot be allocated.
+ *
+ * The socket lock must be held by the caller.
+ */
+int vsock_remote_addr_update(struct vsock_sock *vsk, struct sockaddr_vm *src)
+{
+	struct sockaddr_vm_rcu *fresh, *stale;
+
+	fresh = kmalloc(sizeof(*fresh), GFP_KERNEL);
+	if (!fresh)
+		return -ENOMEM;
+	memcpy(&fresh->addr, src, sizeof(fresh->addr));
+
+	stale = rcu_replace_pointer(vsk->remote_addr, fresh,
+				    lockdep_sock_is_held(sk_vsock(vsk)));
+	kfree_rcu(stale, rcu);
+
+	return 0;
+}
+
+/* Compare the socket's remote address against @other.
+ *
+ * Returns false when no remote address is attached to @vsk; otherwise
+ * returns the result of vsock_addr_equals_addr() on the two addresses.
+ */
+bool vsock_remote_addr_equals(struct vsock_sock *vsk,
+			      struct sockaddr_vm *other)
+{
+	struct sockaddr_vm_rcu *raddr;
+	bool match = false;
+
+	rcu_read_lock();
+	raddr = rcu_dereference(vsk->remote_addr);
+	if (raddr)
+		match = vsock_addr_equals_addr(&raddr->addr, other);
+	rcu_read_unlock();
+
+	return match;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_equals);
/* Each bound VSocket is stored in the bind hash table and each connected
* VSocket is stored in the connected hash table.
@@ -254,10 +387,16 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
connected_table) {
- if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+		struct sockaddr_vm_rcu *remote_addr;
+
+		rcu_read_lock();
+		remote_addr = rcu_dereference(vsk->remote_addr);
+		/* Treat a missing remote address as "no match", like the
+		 * other RCU readers of this field, instead of dereferencing
+		 * a NULL pointer.
+		 */
+		if (remote_addr &&
+		    vsock_addr_equals_addr(src, &remote_addr->addr) &&
dst->svm_port == vsk->local_addr.svm_port) {
+			rcu_read_unlock();
return sk_vsock(vsk);
}
+		rcu_read_unlock();
}
return NULL;
@@ -270,14 +409,25 @@ static void vsock_insert_unbound(struct vsock_sock *vsk)
spin_unlock_bh(&vsock_table_lock);
}
-void vsock_insert_connected(struct vsock_sock *vsk)
+/* Add @vsk to the connected-table bucket selected from its remote and
+ * local addresses, under vsock_table_lock.
+ *
+ * Returns 0 on success or -EINVAL, without touching the table, when no
+ * remote address is attached to @vsk.
+ */
+int vsock_insert_connected(struct vsock_sock *vsk)
{
- struct list_head *list = vsock_connected_sockets(
- &vsk->remote_addr, &vsk->local_addr);
+	struct sockaddr_vm_rcu *raddr;
+	struct list_head *list = NULL;
+	int err = -EINVAL;
+
+	rcu_read_lock();
+	raddr = rcu_dereference(vsk->remote_addr);
+	if (raddr) {
+		list = vsock_connected_sockets(&raddr->addr, &vsk->local_addr);
+		err = 0;
+	}
+	rcu_read_unlock();
+	if (err)
+		return err;
spin_lock_bh(&vsock_table_lock);
__vsock_insert_connected(list, vsk);
spin_unlock_bh(&vsock_table_lock);
+
+	return 0;
}
EXPORT_SYMBOL_GPL(vsock_insert_connected);
@@ -438,10 +588,17 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
{
const struct vsock_transport *new_transport;
struct sock *sk = sk_vsock(vsk);
- unsigned int remote_cid = vsk->remote_addr.svm_cid;
+ struct sockaddr_vm remote_addr;
+ unsigned int remote_cid;
__u8 remote_flags;
int ret;
+ ret = vsock_remote_addr_copy(vsk, &remote_addr);
+ if (ret < 0)
+ return ret;
+
+ remote_cid = remote_addr.svm_cid;
+
/* If the packet is coming with the source and destination CIDs higher
* than VMADDR_CID_HOST, then a vsock channel where all the packets are
* forwarded to the host should be established. Then the host will
@@ -451,10 +608,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
* the connect path the flag can be set by the user space application.
*/
if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
- vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
- vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+	    remote_addr.svm_cid > VMADDR_CID_HOST) {
+		/* Set the to-host flag bit, as the replaced code did;
+		 * VMADDR_CID_HOST is a CID value, not an svm_flags bit.
+		 */
+		remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+
+		ret = vsock_remote_addr_update(vsk, &remote_addr);
+		if (ret < 0)
+			return ret;
+	}
- remote_flags = vsk->remote_addr.svm_flags;
+ remote_flags = remote_addr.svm_flags;
switch (sk->sk_type) {
case SOCK_DGRAM:
@@ -742,6 +904,7 @@ static struct sock *__vsock_create(struct net *net,
unsigned short type,
int kern)
{
+	struct sockaddr_vm_rcu *remote_addr;
struct sock *sk;
struct vsock_sock *psk;
struct vsock_sock *vsk;
@@ -761,7 +924,14 @@ static struct sock *__vsock_create(struct net *net,
vsk = vsock_sk(sk);
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+
+	/* Allocate the whole sockaddr_vm_rcu wrapper: vsk->remote_addr points
+	 * at that type, and the update helpers release the old address with
+	 * kfree_rcu(old, rcu), which needs the embedded rcu_head to lie
+	 * inside the allocation.
+	 */
+	remote_addr = kmalloc(sizeof(*remote_addr), GFP_KERNEL);
+	if (!remote_addr) {
+		sk_free(sk);
+		return NULL;
+	}
+	vsock_addr_init(&remote_addr->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+	rcu_assign_pointer(vsk->remote_addr, remote_addr);
sk->sk_destruct = vsock_sk_destruct;
sk->sk_backlog_rcv = vsock_queue_rcv_skb;
@@ -845,6 +1015,7 @@ static void __vsock_release(struct sock *sk, int level)
static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ struct sockaddr_vm_rcu *remote_addr;
vsock_deassign_transport(vsk);
@@ -852,8 +1023,8 @@ static void vsock_sk_destruct(struct sock *sk)
* possibly register the address family with the kernel.
*/
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-
+ remote_addr = rcu_replace_pointer(vsk->remote_addr, NULL, 1);
+ kfree_rcu(remote_addr);
put_cred(vsk->owner);
}
@@ -943,6 +1114,7 @@ static int vsock_getname(struct socket *sock,
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *vm_addr;
+ struct sockaddr_vm_rcu *rcu_ptr;
sk = sock->sk;
vsk = vsock_sk(sk);
@@ -951,11 +1123,17 @@ static int vsock_getname(struct socket *sock,
lock_sock(sk);
if (peer) {
+ rcu_read_lock();
if (sock->state != SS_CONNECTED) {
err = -ENOTCONN;
goto out;
}
- vm_addr = &vsk->remote_addr;
+ rcu_ptr = rcu_dereference(vsk->remote_addr);
+ if (!rcu_ptr) {
+ err = -EINVAL;
+ goto out;
+ }
+ vm_addr = &rcu_ptr->addr;
} else {
vm_addr = &vsk->local_addr;
}
@@ -975,6 +1153,8 @@ static int vsock_getname(struct socket *sock,
err = sizeof(*vm_addr);
out:
+ if (peer)
+ rcu_read_unlock();
release_sock(sk);
return err;
}
@@ -1161,7 +1341,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
int err;
struct sock *sk;
struct vsock_sock *vsk;
- struct sockaddr_vm *remote_addr;
+ struct sockaddr_vm stack_addr, *remote_addr;
const struct vsock_transport *transport;
if (msg->msg_flags & MSG_OOB)
@@ -1172,15 +1352,26 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
sk = sock->sk;
vsk = vsock_sk(sk);
- lock_sock(sk);
+ /* If auto-binding is required, acquire the slock to avoid potential
+ * race conditions. Otherwise, do not acquire the lock.
+ *
+ * We know that the first check of local_addr is racy (indicated by
+ * data_race()). By acquiring the lock and then subsequently checking
+ * again if local_addr is bound (inside vsock_auto_bind()), we can
+ * ensure there are no real data races.
+ *
+ * This technique is borrowed from inet_send_prepare().
+ */
+ if (data_race(!vsock_addr_bound(&vsk->local_addr))) {
+ lock_sock(sk);
+ err = vsock_auto_bind(vsk);
+ release_sock(sk);
+ if (err)
+ return err;
+ }
transport = vsk->transport;
- err = vsock_auto_bind(vsk);
- if (err)
- goto out;
-
-
/* If the provided message contains an address, use that. Otherwise
* fall back on the socket's remote handle (if it has been connected).
*/
@@ -1199,18 +1390,26 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
} else if (sock->state == SS_CONNECTED) {
- remote_addr = &vsk->remote_addr;
+ err = vsock_remote_addr_copy(vsk, &stack_addr);
+ if (err < 0)
+ goto out;
- if (remote_addr->svm_cid == VMADDR_CID_ANY)
- remote_addr->svm_cid = transport->get_local_cid();
+ if (stack_addr.svm_cid == VMADDR_CID_ANY) {
+ stack_addr.svm_cid = transport->get_local_cid();
+ lock_sock(sk_vsock(vsk));
+ vsock_remote_addr_update(vsk, &stack_addr);
+ release_sock(sk_vsock(vsk));
+ }
/* XXX Should connect() or this function ensure remote_addr is
* bound?
*/
- if (!vsock_addr_bound(&vsk->remote_addr)) {
+ if (!vsock_addr_bound(&stack_addr)) {
err = -EINVAL;
goto out;
}
+
+ remote_addr = &stack_addr;
} else {
err = -EINVAL;
goto out;
@@ -1225,7 +1424,6 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
out:
- release_sock(sk);
return err;
}
@@ -1243,8 +1441,7 @@ static int vsock_dgram_connect(struct socket *sock,
err = vsock_addr_cast(addr, addr_len, &remote_addr);
if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
lock_sock(sk);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
- VMADDR_PORT_ANY);
+ vsock_remote_addr_update_cid_port(vsk, VMADDR_CID_ANY, VMADDR_PORT_ANY);
sock->state = SS_UNCONNECTED;
release_sock(sk);
return 0;
@@ -1263,7 +1460,10 @@ static int vsock_dgram_connect(struct socket *sock,
goto out;
}
- memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
+ err = vsock_remote_addr_update(vsk, remote_addr);
+ if (err < 0)
+ goto out;
+
sock->state = SS_CONNECTED;
/* sock map disallows redirection of non-TCP sockets with sk_state !=
@@ -1399,8 +1599,9 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
}
/* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
+ err = vsock_remote_addr_update(vsk, remote_addr);
+ if (err)
+ goto out;
err = vsock_assign_transport(vsk, NULL);
if (err)
@@ -1831,7 +2032,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (!vsock_addr_bound(&vsk->remote_addr)) {
+ if (!vsock_remote_addr_bound(vsk)) {
err = -EDESTADDRREQ;
goto out;
}
diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c
index a2823b1c5e28..f843bae86b32 100644
--- a/net/vmw_vsock/diag.c
+++ b/net/vmw_vsock/diag.c
@@ -15,8 +15,14 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
u32 portid, u32 seq, u32 flags)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ struct sockaddr_vm remote_addr;
struct vsock_diag_msg *rep;
struct nlmsghdr *nlh;
+ int err;
+
+ err = vsock_remote_addr_copy(vsk, &remote_addr);
+ if (err < 0)
+ return err;
nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
flags);
@@ -36,8 +42,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
rep->vdiag_shutdown = sk->sk_shutdown;
rep->vdiag_src_cid = vsk->local_addr.svm_cid;
rep->vdiag_src_port = vsk->local_addr.svm_port;
- rep->vdiag_dst_cid = vsk->remote_addr.svm_cid;
- rep->vdiag_dst_port = vsk->remote_addr.svm_port;
+ rep->vdiag_dst_cid = remote_addr.svm_cid;
+ rep->vdiag_dst_port = remote_addr.svm_port;
rep->vdiag_ino = sock_i_ino(sk);
sock_diag_save_cookie(sk, rep->vdiag_cookie);
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 7cb1a9d2cdb4..462b2ec3e6e9 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -336,9 +336,11 @@ static void hvs_open_connection(struct vmbus_channel *chan)
hvs_addr_init(&vnew->local_addr, if_type);
/* Remote peer is always the host */
- vsock_addr_init(&vnew->remote_addr,
- VMADDR_CID_HOST, VMADDR_PORT_ANY);
- vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
+ ret = vsock_remote_addr_update_cid_port(vnew, VMADDR_CID_HOST,
+ get_port_by_srv_id(if_instance));
+ if (ret < 0)
+ goto out;
+
ret = vsock_assign_transport(vnew, vsock_sk(sk));
/* Transport assigned (looking at remote_addr) must be the
* same where we received the request.
@@ -459,13 +461,18 @@ static int hvs_connect(struct vsock_sock *vsk)
{
union hvs_service_id vm, host;
struct hvsock *h = vsk->trans;
+ int err;
vm.srv_id = srv_id_template;
vm.svm_port = vsk->local_addr.svm_port;
h->vm_srv_id = vm.srv_id;
host.srv_id = srv_id_template;
- host.svm_port = vsk->remote_addr.svm_port;
+
+ err = vsock_remote_addr_port(vsk, &host.svm_port);
+ if (err < 0)
+ return err;
+
h->host_srv_id = host.srv_id;
return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 925acface893..1b87704e516a 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -258,8 +258,9 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
src_cid = t_ops->transport.get_local_cid();
src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) {
- dst_cid = vsk->remote_addr.svm_cid;
- dst_port = vsk->remote_addr.svm_port;
+ ret = vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port);
+ if (ret < 0)
+ return ret;
} else {
dst_cid = info->remote_cid;
dst_port = info->remote_port;
@@ -1169,7 +1170,9 @@ virtio_transport_recv_connecting(struct sock *sk,
case VIRTIO_VSOCK_OP_RESPONSE:
sk->sk_state = TCP_ESTABLISHED;
sk->sk_socket->state = SS_CONNECTED;
- vsock_insert_connected(vsk);
+ err = vsock_insert_connected(vsk);
+ if (err)
+ goto destroy;
sk->sk_state_change(sk);
break;
case VIRTIO_VSOCK_OP_INVALID:
@@ -1403,9 +1406,8 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
vchild = vsock_sk(child);
vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
le32_to_cpu(hdr->dst_port));
- vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
- le32_to_cpu(hdr->src_port));
-
+ vsock_remote_addr_update_cid_port(vchild, le64_to_cpu(hdr->src_cid),
+ le32_to_cpu(hdr->src_port));
ret = vsock_assign_transport(vchild, vsk);
/* Transport assigned (looking at remote_addr) must be the same
* where we received the request.
@@ -1420,7 +1422,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
if (virtio_transport_space_update(child, skb))
child->sk_write_space(child);
- vsock_insert_connected(vchild);
+ ret = vsock_insert_connected(vchild);
+ if (ret) {
+ release_sock(child);
+ virtio_transport_reset_no_sock(t, skb);
+ sock_put(child);
+ return ret;
+ }
vsock_enqueue_accept(sk, child);
virtio_transport_send_response(vchild, skb);
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index b370070194fa..c0c445e7d925 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -283,18 +283,25 @@ vmci_transport_send_control_pkt(struct sock *sk,
u16 proto,
struct vmci_handle handle)
{
+ struct sockaddr_vm addr_stack;
+ struct sockaddr_vm *remote_addr = &addr_stack;
struct vsock_sock *vsk;
+ int err;
vsk = vsock_sk(sk);
if (!vsock_addr_bound(&vsk->local_addr))
return -EINVAL;
- if (!vsock_addr_bound(&vsk->remote_addr))
+ if (!vsock_remote_addr_bound(vsk))
return -EINVAL;
+ err = vsock_remote_addr_copy(vsk, &addr_stack);
+ if (err < 0)
+ return err;
+
return vmci_transport_alloc_send_control_pkt(&vsk->local_addr,
- &vsk->remote_addr,
+ remote_addr,
type, size, mode,
wait, proto, handle);
}
@@ -317,6 +324,7 @@ static int vmci_transport_send_reset(struct sock *sk,
struct sockaddr_vm *dst_ptr;
struct sockaddr_vm dst;
struct vsock_sock *vsk;
+ int err;
if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
return 0;
@@ -326,13 +334,16 @@ static int vmci_transport_send_reset(struct sock *sk,
if (!vsock_addr_bound(&vsk->local_addr))
return -EINVAL;
- if (vsock_addr_bound(&vsk->remote_addr)) {
- dst_ptr = &vsk->remote_addr;
+ if (vsock_remote_addr_bound(vsk)) {
+ err = vsock_remote_addr_copy(vsk, &dst);
+ if (err < 0)
+ return err;
} else {
vsock_addr_init(&dst, pkt->dg.src.context,
pkt->src_port);
- dst_ptr = &dst;
}
+ dst_ptr = &dst;
+
return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr,
VMCI_TRANSPORT_PACKET_TYPE_RST,
0, 0, NULL, VSOCK_PROTO_INVALID,
@@ -490,7 +501,7 @@ static struct sock *vmci_transport_get_pending(
list_for_each_entry(vpending, &vlistener->pending_links,
pending_links) {
- if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
+ if (vsock_remote_addr_equals(vpending, &src) &&
pkt->dst_port == vpending->local_addr.svm_port) {
pending = sk_vsock(vpending);
sock_hold(pending);
@@ -1015,8 +1026,8 @@ static int vmci_transport_recv_listen(struct sock *sk,
vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
pkt->dst_port);
- vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
- pkt->src_port);
+ vsock_remote_addr_update_cid_port(vpending, pkt->dg.src.context,
+ pkt->src_port);
err = vsock_assign_transport(vpending, vsock_sk(sk));
/* Transport assigned (looking at remote_addr) must be the same
@@ -1133,6 +1144,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
{
struct vsock_sock *vpending;
struct vmci_handle handle;
+ unsigned int vpending_remote_cid;
struct vmci_qp *qpair;
bool is_local;
u32 flags;
@@ -1189,8 +1201,13 @@ vmci_transport_recv_connecting_server(struct sock *listener,
/* vpending->local_addr always has a context id so we do not need to
* worry about VMADDR_CID_ANY in this case.
*/
- is_local =
- vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
+ err = vsock_remote_addr_cid(vpending, &vpending_remote_cid);
+ if (err < 0) {
+ skerr = EPROTO;
+ goto destroy;
+ }
+
+ is_local = vpending_remote_cid == vpending->local_addr.svm_cid;
flags = VMCI_QPFLAG_ATTACH_ONLY;
flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;
@@ -1203,7 +1220,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
flags,
vmci_transport_is_trusted(
vpending,
- vpending->remote_addr.svm_cid));
+ vpending_remote_cid));
if (err < 0) {
vmci_transport_send_reset(pending, pkt);
skerr = -err;
@@ -1306,9 +1323,20 @@ vmci_transport_recv_connecting_client(struct sock *sk,
break;
case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
+ struct sockaddr_vm_rcu *remote_addr;
+
+ rcu_read_lock();
+ remote_addr = rcu_dereference(vsk->remote_addr);
+ if (!remote_addr) {
+ skerr = EPROTO;
+ err = -EINVAL;
+ rcu_read_unlock();
+ goto destroy;
+ }
+
if (pkt->u.size == 0
- || pkt->dg.src.context != vsk->remote_addr.svm_cid
- || pkt->src_port != vsk->remote_addr.svm_port
+ || pkt->dg.src.context != remote_addr->addr.svm_cid
+ || pkt->src_port != remote_addr->addr.svm_port
|| !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
|| vmci_trans(vsk)->qpair
|| vmci_trans(vsk)->produce_size != 0
@@ -1316,9 +1344,10 @@ vmci_transport_recv_connecting_client(struct sock *sk,
|| vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
skerr = EPROTO;
err = -EINVAL;
-
+ rcu_read_unlock();
goto destroy;
}
+ rcu_read_unlock();
err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
if (err) {
@@ -1379,6 +1408,7 @@ static int vmci_transport_recv_connecting_client_negotiate(
int err;
struct vsock_sock *vsk;
struct vmci_handle handle;
+ unsigned int remote_cid;
struct vmci_qp *qpair;
u32 detach_sub_id;
bool is_local;
@@ -1449,19 +1479,23 @@ static int vmci_transport_recv_connecting_client_negotiate(
/* Make VMCI select the handle for us. */
handle = VMCI_INVALID_HANDLE;
- is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
+
+ err = vsock_remote_addr_cid(vsk, &remote_cid);
+ if (err < 0)
+ goto destroy;
+
+ is_local = remote_cid == vsk->local_addr.svm_cid;
flags = is_local ? VMCI_QPFLAG_LOCAL : 0;
err = vmci_transport_queue_pair_alloc(&qpair,
&handle,
pkt->u.size,
pkt->u.size,
- vsk->remote_addr.svm_cid,
+ remote_cid,
flags,
vmci_transport_is_trusted(
vsk,
- vsk->
- remote_addr.svm_cid));
+ remote_cid));
if (err < 0)
goto destroy;
--
2.30.2