Add a new argument, "src_cid", to the seqpacket_allow() transport callback so that transports know which `virtio_vsock` device to select. Signed-off-by: Xuewei Niu <niuxuewei.nxw@xxxxxxxxxxxx> --- include/net/af_vsock.h | 2 +- net/vmw_vsock/af_vsock.c | 15 +++++++++++++-- net/vmw_vsock/virtio_transport.c | 4 ++-- net/vmw_vsock/vsock_loopback.c | 4 ++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index 0151296a0bc5..25f7dc3d602d 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -143,7 +143,7 @@ struct vsock_transport { int flags); int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg, size_t len); - bool (*seqpacket_allow)(u32 remote_cid); + bool (*seqpacket_allow)(u32 src_cid, u32 remote_cid); u32 (*seqpacket_has_data)(struct vsock_sock *vsk); /* Notification. */ diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index da06ddc940cd..3b34be802bf2 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -470,10 +470,12 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) { const struct vsock_transport *new_transport; struct sock *sk = sk_vsock(vsk); - unsigned int remote_cid = vsk->remote_addr.svm_cid; + unsigned int src_cid, remote_cid; __u8 remote_flags; int ret; + remote_cid = vsk->remote_addr.svm_cid; + /* If the packet is coming with the source and destination CIDs higher * than VMADDR_CID_HOST, then a vsock channel where all the packets are * forwarded to the host should be established.
Then the host will @@ -527,8 +529,17 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) return -ENODEV; if (sk->sk_type == SOCK_SEQPACKET) { + if (vsk->local_addr.svm_cid == VMADDR_CID_ANY) { + if (new_transport->get_default_cid) + src_cid = new_transport->get_default_cid(); + else + src_cid = new_transport->get_local_cid(); + } else { + src_cid = vsk->local_addr.svm_cid; + } + if (!new_transport->seqpacket_allow || - !new_transport->seqpacket_allow(remote_cid)) { + !new_transport->seqpacket_allow(src_cid, remote_cid)) { module_put(new_transport->module); return -ESOCKTNOSUPPORT; } diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 998b22e5ce36..0bddcbd906a2 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -615,14 +615,14 @@ static struct virtio_transport virtio_transport = { .can_msgzerocopy = virtio_transport_can_msgzerocopy, }; -static bool virtio_transport_seqpacket_allow(u32 remote_cid) +static bool virtio_transport_seqpacket_allow(u32 src_cid, u32 remote_cid) { struct virtio_vsock *vsock; bool seqpacket_allow; seqpacket_allow = false; rcu_read_lock(); - vsock = rcu_dereference(the_virtio_vsock); + vsock = virtio_transport_get_virtio_vsock(src_cid); if (vsock) seqpacket_allow = vsock->seqpacket_allow; rcu_read_unlock(); diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c index 6dea6119f5b2..b94358f5bb2c 100644 --- a/net/vmw_vsock/vsock_loopback.c +++ b/net/vmw_vsock/vsock_loopback.c @@ -46,7 +46,7 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) return 0; } -static bool vsock_loopback_seqpacket_allow(u32 remote_cid); +static bool vsock_loopback_seqpacket_allow(u32 src_cid, u32 remote_cid); static bool vsock_loopback_msgzerocopy_allow(void) { return true; @@ -104,7 +104,7 @@ static struct virtio_transport loopback_transport = { .send_pkt = vsock_loopback_send_pkt, }; -static bool vsock_loopback_seqpacket_allow(u32 
remote_cid) +static bool vsock_loopback_seqpacket_allow(u32 src_cid, u32 remote_cid) { return true; } -- 2.34.1