Expand the multishot recv test to include recvmsg. This also checks that
the source sockaddr comes back, and that control messages work properly.

Signed-off-by: Dylan Yudaken <dylany@xxxxxx>
---
 test/recv-multishot.c | 180 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 161 insertions(+), 19 deletions(-)

diff --git a/test/recv-multishot.c b/test/recv-multishot.c
index 9df8184..a322e43 100644
--- a/test/recv-multishot.c
+++ b/test/recv-multishot.c
@@ -27,20 +27,45 @@ enum early_error_t {
 struct args {
 	bool stream;
 	bool wait_each;
+	bool recvmsg;
 	enum early_error_t early_error;
 };
 
+static int check_sockaddr(struct sockaddr_in *in)
+{
+	struct in_addr expected;
+
+	inet_pton(AF_INET, "127.0.0.1", &expected);
+	if (in->sin_family != AF_INET) {
+		fprintf(stderr, "bad family %d\n", (int)in->sin_family);
+		return -1;
+	}
+	if (memcmp(&expected, &in->sin_addr, sizeof(in->sin_addr))) {
+		char buff[256];
+		const char *addr = inet_ntop(AF_INET, &in->sin_addr, buff, sizeof(buff));
+
+		fprintf(stderr, "unexpected address %s\n", addr ? addr : "INVALID");
+		return -1;
+	}
+	return 0;
+}
+
 static int test(struct args *args)
 {
 	int const N = 8;
 	int const N_BUFFS = N * 64;
 	int const N_CQE_OVERFLOW = 4;
 	int const min_cqes = 2;
+	int const NAME_LEN = sizeof(struct sockaddr_storage);
+	int const CONTROL_LEN = CMSG_ALIGN(sizeof(struct sockaddr_storage))
+				+ sizeof(struct cmsghdr);
 	struct io_uring ring;
 	struct io_uring_cqe *cqe;
 	struct io_uring_sqe *sqe;
-	int fds[2], ret, i, j, total_sent_bytes = 0, total_recv_bytes = 0;
+	int fds[2], ret, i, j;
+	int total_sent_bytes = 0, total_recv_bytes = 0, total_dropped_bytes = 0;
 	int send_buff[256];
+	int *sent_buffs[N_BUFFS];
 	int *recv_buffs[N_BUFFS];
 	int *at;
 	struct io_uring_cqe recv_cqe[N_BUFFS];
@@ -50,7 +75,7 @@ static int test(struct args *args)
 	struct __kernel_timespec timeout = {
 		.tv_sec = 1,
 	};
-	
+	struct msghdr msg;
 
 	memset(recv_buffs, 0, sizeof(recv_buffs));
 
@@ -75,21 +100,42 @@ static int test(struct args *args)
 		return ret;
 	}
 
+	if (!args->stream) {
+		int val = 1;
+
+		/* force some cmsgs to come back to us */
+		ret = setsockopt(fds[0], IPPROTO_IP, IP_RECVORIGDSTADDR, &val,
+				 sizeof(val));
+		if (ret) {
+			fprintf(stderr, "setsockopt failed %d\n", errno);
+			goto cleanup;
+		}
+	}
+
 	for (i = 0; i < ARRAY_SIZE(send_buff); i++)
 		send_buff[i] = i;
 
 	for (i = 0; i < ARRAY_SIZE(recv_buffs); i++) {
 		/* prepare some different sized buffers */
-		int buffer_size = (i % 2 == 0 && args->stream) ? 1 : N * sizeof(int);
+		int buffer_size = (i % 2 == 0 && (args->stream || args->recvmsg)) ? 1 : N;
+
+		buffer_size *= sizeof(int);
+		if (args->recvmsg) {
+			buffer_size +=
+				sizeof(struct io_uring_recvmsg_out) +
+				NAME_LEN +
+				CONTROL_LEN;
+		}
 
-		recv_buffs[i] = malloc(sizeof(*at) * buffer_size);
+		recv_buffs[i] = malloc(buffer_size);
 
 		if (i > 2 && args->early_error == ERROR_NOT_ENOUGH_BUFFERS)
 			continue;
 
 		sqe = io_uring_get_sqe(&ring);
 		io_uring_prep_provide_buffers(sqe, recv_buffs[i],
-					buffer_size * sizeof(*recv_buffs[i]), 1, 7, i);
+					buffer_size, 1, 7, i);
+		memset(recv_buffs[i], 0xcc, buffer_size);
 		if (io_uring_submit_and_wait_timeout(&ring, &cqe, 1, &timeout, NULL) != 0) {
 			fprintf(stderr, "provide buffers failed: %d\n", ret);
 			ret = -1;
@@ -99,7 +145,19 @@ static int test(struct args *args)
 	}
 
 	sqe = io_uring_get_sqe(&ring);
-	io_uring_prep_recv_multishot(sqe, fds[0], NULL, 0, 0);
+	if (args->recvmsg) {
+		unsigned int flags = 0;
+
+		if (!args->stream)
+			flags |= MSG_TRUNC;
+
+		memset(&msg, 0, sizeof(msg));
+		msg.msg_namelen = NAME_LEN;
+		msg.msg_controllen = CONTROL_LEN;
+		io_uring_prep_recvmsg_multishot(sqe, fds[0], &msg, flags);
+	} else {
+		io_uring_prep_recv_multishot(sqe, fds[0], NULL, 0, 0);
+	}
 	sqe->flags |= IOSQE_BUFFER_SELECT;
 	sqe->buf_group = 7;
 	io_uring_sqe_set_data64(sqe, 1234);
@@ -111,6 +169,7 @@ static int test(struct args *args)
 		int to_send = sizeof(*at) * (i+1);
 
 		total_sent_bytes += to_send;
+		sent_buffs[i] = at;
 		if (send(fds[1], at, to_send, 0) != to_send) {
 			if (early_error_started)
 				break;
@@ -202,9 +261,12 @@ static int test(struct args *args)
 				(args->early_error == ERROR_EARLY_OVERFLOW &&
 				 !args->wait_each && i == N_CQE_OVERFLOW);
 		int *this_recv;
+		int orig_payload_size = cqe->res;
 
 
 		if (should_be_last) {
+			int used_res = cqe->res;
+
 			if (!is_last) {
 				fprintf(stderr, "not last cqe had error %d\n", i);
 				goto cleanup;
@@ -234,7 +296,22 @@ static int test(struct args *args)
 			break;
 		case ERROR_NONE:
 		case ERROR_EARLY_CLOSE_SENDER:
-			if (cqe->res != 0) {
+			if (args->recvmsg && (cqe->flags & IORING_CQE_F_BUFFER)) {
+				void *buff = recv_buffs[cqe->flags >> 16];
+				struct io_uring_recvmsg_out *o =
+					io_uring_recvmsg_validate(buff, cqe->res, &msg);
+
+				if (!o) {
+					fprintf(stderr, "invalid buff\n");
+					goto cleanup;
+				}
+				if (o->payloadlen != 0) {
+					fprintf(stderr, "expected 0 payloadlen, got %u\n",
+						o->payloadlen);
+					goto cleanup;
+				}
+				used_res = 0;
+			} else if (cqe->res != 0) {
 				fprintf(stderr, "early error: res %d\n", cqe->res);
 				goto cleanup;
 			}
@@ -254,7 +331,7 @@ static int test(struct args *args)
 				goto cleanup;
 			}
 
-			if (cqe->res <= 0)
+			if (used_res <= 0)
 				continue;
 		} else {
 			if (!(cqe->flags & IORING_CQE_F_MORE)) {
@@ -268,7 +345,61 @@ static int test(struct args *args)
 			goto cleanup;
 		}
 
+		this_recv = recv_buffs[cqe->flags >> 16];
+
+		if (args->recvmsg) {
+			struct io_uring_recvmsg_out *o = io_uring_recvmsg_validate(
+				this_recv, cqe->res, &msg);
+
+			if (!o) {
+				fprintf(stderr, "bad recvmsg\n");
+				goto cleanup;
+			}
+			orig_payload_size = o->payloadlen;
+
+			if (!args->stream) {
+				struct cmsghdr *cmsg;
+
+				/* udp: the sender address must be 127.0.0.1 */
+				if (o->namelen < sizeof(struct sockaddr_in)) {
+					fprintf(stderr, "bad addr len %u\n",
+						o->namelen);
+					goto cleanup;
+				}
+				if (check_sockaddr((struct sockaddr_in *)io_uring_recvmsg_name(o)))
+					goto cleanup;
+
+				/* exactly one IP_RECVORIGDSTADDR cmsg, also 127.0.0.1 */
+				cmsg = io_uring_recvmsg_cmsg_firsthdr(o, &msg);
+				if (!cmsg ||
+				    cmsg->cmsg_level != IPPROTO_IP ||
+				    cmsg->cmsg_type != IP_RECVORIGDSTADDR) {
+					fprintf(stderr, "bad cmsg\n");
+					goto cleanup;
+				}
+				if (check_sockaddr((struct sockaddr_in *)CMSG_DATA(cmsg)))
+					goto cleanup;
+				cmsg = io_uring_recvmsg_cmsg_nexthdr(o, &msg, cmsg);
+				if (cmsg) {
+					fprintf(stderr, "unexpected extra cmsg\n");
+					goto cleanup;
+				}
+			}
+
+			/* payloadlen may exceed what fit in the buffer (MSG_TRUNC) */
+			this_recv = (int *)io_uring_recvmsg_payload(o, &msg);
+			cqe->res = io_uring_recvmsg_payload_length(o, cqe->res, &msg);
+			if (o->payloadlen != cqe->res) {
+				if (!(o->flags & MSG_TRUNC)) {
+					fprintf(stderr, "expected truncated flag\n");
+					goto cleanup;
+				}
+				total_dropped_bytes += (o->payloadlen - cqe->res);
+			}
+		}
+
 		total_recv_bytes += cqe->res;
+
 		if (cqe->res % 4 != 0) {
 			/*
 			 * doesn't seem to happen in practice, would need some
@@ -278,9 +409,20 @@ static int test(struct args *args)
 			goto cleanup;
 		}
 
-		/* check buffer arrived in order (for tcp) */
-		this_recv = recv_buffs[cqe->flags >> 16];
-		for (j = 0; args->stream && j < cqe->res / 4; j++) {
+		/*
+		 * for tcp: check buffer arrived in order
+		 * for udp: find the sent buffer from the original payload size
+		 */
+		if (!args->stream) {
+			int sent_idx = orig_payload_size / (int)sizeof(*at) - 1;
+
+			if (sent_idx < 0 || sent_idx >= N) {
+				fprintf(stderr, "Bad sent idx: %d\n", sent_idx);
+				goto cleanup;
+			}
+			at = sent_buffs[sent_idx];
+		}
+		for (j = 0; j < cqe->res / 4; j++) {
 			int sent = *at++;
 			int recv = *this_recv++;
 
@@ -291,15 +433,14 @@ static int test(struct args *args)
 		}
 	}
 
-	if (args->early_error == ERROR_NONE && total_recv_bytes < total_sent_bytes) {
+	if (args->early_error == ERROR_NONE &&
+	    total_recv_bytes + total_dropped_bytes < total_sent_bytes) {
 		fprintf(stderr,
-			"missing recv: recv=%d sent=%d\n", total_recv_bytes, total_sent_bytes);
+			"missing recv: recv=%d dropped=%d sent=%d\n",
+			total_recv_bytes, total_dropped_bytes, total_sent_bytes);
 		goto cleanup;
 	}
 
-	/* check the final one */
-	cqe = &recv_cqe[recv_cqes-1];
-
 	ret = 0;
 cleanup:
 	for (i = 0; i < ARRAY_SIZE(recv_buffs); i++)
@@ -320,18 +461,19 @@ int main(int argc, char *argv[])
 	if (argc > 1)
 		return T_EXIT_SKIP;
 
-	for (loop = 0; loop < 4; loop++) {
+	for (loop = 0; loop < 8; loop++) {
 		struct args a = {
 			.stream = loop & 0x01,
 			.wait_each = loop & 0x2,
+			.recvmsg = loop & 0x04,
 		};
 		for (early_error = 0; early_error < ERROR_EARLY_LAST; early_error++) {
 			a.early_error = (enum early_error_t)early_error;
 			ret = test(&a);
 			if (ret) {
 				fprintf(stderr,
-					"test stream=%d wait_each=%d early_error=%d failed\n",
-					a.stream, a.wait_each, a.early_error);
+					"test stream=%d wait_each=%d recvmsg=%d early_error=%d failed\n",
+					a.stream, a.wait_each, a.recvmsg, a.early_error);
 				return T_EXIT_FAIL;
 			}
 			if (no_recv_mshot)
-- 
2.30.2