From: Ilya Lesokhin <ilyal@xxxxxxxxxxxx>

Add a software implementation of the TLS transmit path, used when
crypto_info->state is TLS_STATE_SW.  Data queued by sendmsg() and
sendpage() is split into records of at most TLS_MAX_PAYLOAD_SIZE
bytes, encrypted with AES-GCM-128 through the kernel crypto API and
handed to TCP with tls_push_frags().  When nothing is queued and the
call ends a record (no MSG_MORE), sendmsg() encrypts straight from the
pinned user pages instead of copying them first.

Signed-off-by: Dave Watson <davejwatson@xxxxxx>
Signed-off-by: Ilya Lesokhin <ilyal@xxxxxxxxxxxx>
Signed-off-by: Aviad Yehezkel <aviadye@xxxxxxxxxxxx>
---
 MAINTAINERS        |   1 +
 include/net/tls.h  |  43 +++
 net/tls/Makefile   |   2 +-
 net/tls/tls_main.c |  37 ++-
 net/tls/tls_sw.c   | 833 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 900 insertions(+), 16 deletions(-)
 create mode 100644 net/tls/tls_sw.c

diff --git a/MAINTAINERS b/MAINTAINERS
index e3b70c3..413c1d9 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -8491,6 +8491,7 @@ M:	Ilya Lesokhin <ilyal@xxxxxxxxxxxx>
 M:	Aviad Yehezkel <aviadye@xxxxxxxxxxxx>
 M:	Boris Pismenny <borisp@xxxxxxxxxxxx>
 M:	Haggai Eran <haggaie@xxxxxxxxxxxx>
+M:	Dave Watson <davejwatson@xxxxxx>
 L:	netdev@xxxxxxxxxxxxxxx
 T:	git git://git.kernel.org/pub/scm/linux/kernel/git/davem/net.git
 T:	git git://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next.git
diff --git a/include/net/tls.h b/include/net/tls.h
index f7f0cde..bb1f41e 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -48,6 +48,7 @@
 #define TLS_CRYPTO_INFO_READY(info)	((info)->cipher_type)
 
 #define TLS_IS_STATE_HW(info)	((info)->state == TLS_STATE_HW)
+#define TLS_IS_STATE_SW(info)	((info)->state == TLS_STATE_SW)
 
 #define TLS_RECORD_TYPE_DATA		0x17
 
@@ -68,6 +69,37 @@ struct tls_offload_context {
 	spinlock_t lock;	/* protects records list */
 };
 
+#define TLS_DATA_PAGES			(TLS_MAX_PAYLOAD_SIZE / PAGE_SIZE)
+/* +1 for aad, +1 for tag, +1 for chaining */
+#define TLS_SG_DATA_SIZE		(TLS_DATA_PAGES + 3)
+#define ALG_MAX_PAGES 16 /* for skb_to_sgvec */
+#define TLS_AAD_SPACE_SIZE		21
+#define TLS_AAD_SIZE			13
+#define TLS_TAG_SIZE			16
+
+#define TLS_NONCE_SIZE			8
+#define TLS_PREPEND_SIZE		(TLS_HEADER_SIZE + TLS_NONCE_SIZE)
+#define TLS_OVERHEAD		(TLS_PREPEND_SIZE + TLS_TAG_SIZE)
+
+struct tls_sw_context {
+	struct sock *sk;
+	void (*sk_write_space)(struct sock *sk);
+	struct crypto_aead *aead_send;
+
+	/* Sending context */
+	struct scatterlist sg_tx_data[TLS_SG_DATA_SIZE];
+	struct scatterlist sg_tx_data2[ALG_MAX_PAGES + 1];
+	char aad_send[TLS_AAD_SPACE_SIZE];
+	char tag_send[TLS_TAG_SIZE];
+	skb_frag_t tx_frag;
+	int wmem_len;
+	int order_npages;
+	struct scatterlist sgaad_send[2];
+	struct scatterlist sgtag_send[2];
+	struct sk_buff_head tx_queue;
+	int unsent;
+};
+
 struct tls_context {
 	union {
 		struct tls_crypto_info crypto_send;
@@ -102,6 +134,11 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_device_sendpage(struct sock *sk, struct page *page,
 			int offset, size_t size, int flags);
 
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx);
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+		    int offset, size_t size, int flags);
+
 struct tls_record_info *tls_get_record(struct tls_offload_context *context,
 				       u32 seq);
 
@@ -174,6 +211,12 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
 	return sk->sk_user_data;
 }
 
+static inline struct tls_sw_context *tls_sw_ctx(
+		const struct tls_context *tls_ctx)
+{
+	return (struct tls_sw_context *)tls_ctx->priv_ctx;
+}
+
 static inline struct tls_offload_context *tls_offload_ctx(
 		const struct tls_context *tls_ctx)
 {
diff --git a/net/tls/Makefile b/net/tls/Makefile
index 65e5677..61457e0 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -4,4 +4,4 @@
 
 obj-$(CONFIG_TLS) += tls.o
 
-tls-y := tls_main.o tls_device.o
+tls-y := tls_main.o tls_device.o tls_sw.o
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 6a3df25..a4efd02 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -46,6 +46,7 @@ MODULE_DESCRIPTION("Transport Layer Security Support");
 MODULE_LICENSE("Dual BSD/GPL");
 
 static struct proto tls_device_prot;
+static struct proto tls_sw_prot;
 
 int tls_push_frags(struct sock *sk,
 		   struct tls_context *ctx,
@@ -188,13 +189,10 @@ int tls_sk_query(struct sock *sk, int optname, char __user *optval,
 			rc = -EINVAL;
 			goto out;
 		}
-		if (TLS_IS_STATE_HW(crypto_info)) {
-			lock_sock(sk);
-			memcpy(crypto_info_aes_gcm_128->iv,
-			       ctx->iv,
-			       TLS_CIPHER_AES_GCM_128_IV_SIZE);
-			release_sock(sk);
-		}
+		lock_sock(sk);
+		memcpy(crypto_info_aes_gcm_128->iv, ctx->iv,
+		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+		release_sock(sk);
 		rc = copy_to_user(optval,
 				  crypto_info_aes_gcm_128,
 				  sizeof(*crypto_info_aes_gcm_128));
@@ -224,6 +222,7 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 	struct tls_context *ctx = tls_get_ctx(sk);
 	struct tls_crypto_info *crypto_info;
 	bool allocated_tls_ctx = false;
+	struct proto *prot = NULL;
 
 	if (!optval || (optlen < sizeof(*crypto_info))) {
 		rc = -EINVAL;
@@ -267,12 +266,6 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 		goto err_sk_user_data;
 	}
 
-	/* currently we support only HW offload */
-	if (!TLS_IS_STATE_HW(crypto_info)) {
-		rc = -ENOPROTOOPT;
-		goto err_crypto_info;
-	}
-
 	/* check version */
 	if (crypto_info->version != TLS_1_2_VERSION) {
 		rc = -ENOTSUPP;
@@ -306,6 +299,15 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 
 	if (TLS_IS_STATE_HW(crypto_info)) {
 		rc = tls_set_device_offload(sk, ctx);
+		prot = &tls_device_prot;
 		if (rc)
 			goto err_crypto_info;
+	} else if (TLS_IS_STATE_SW(crypto_info)) {
+		rc = tls_set_sw_offload(sk, ctx);
+		prot = &tls_sw_prot;
+		if (rc)
+			goto err_crypto_info;
+	} else {
+		rc = -ENOPROTOOPT;
+		goto err_crypto_info;
 	}
@@ -315,8 +317,9 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 		goto err_set_device_offload;
 	}
 
-	/* TODO: add protection */
-	sk->sk_prot = &tls_device_prot;
+	rc = 0;
+
+	sk->sk_prot = prot;
 	goto out;
 
 err_set_device_offload:
@@ -337,6 +340,10 @@ static int __init tls_init(void)
 	tls_device_prot.sendmsg = tls_device_sendmsg;
 	tls_device_prot.sendpage = tls_device_sendpage;
 
+	tls_sw_prot = tcp_prot;
+	tls_sw_prot.sendmsg = tls_sw_sendmsg;
+	tls_sw_prot.sendpage = tls_sw_sendpage;
+
 	return 0;
 }
 
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
new file mode 100644
index 0000000..4698dc7
--- /dev/null
+++ b/net/tls/tls_sw.c
@@ -0,0 +1,833 @@
+/*
+ * tls_sw: software (kernel crypto API) TLS transmit path
+ *
+ * Copyright (C) 2016
+ *
+ * Original authors:
+ *   Fridolin Pokorny <fridolin.pokorny@xxxxxxxxx>
+ *   Nikos Mavrogiannopoulos <nmav@xxxxxxxxxx>
+ *   Dave Watson <davejwatson@xxxxxx>
+ *   Lance Chao <lancerchao@xxxxxx>
+ *
+ * Based on RFC 5288, RFC 6347, RFC 5246, RFC 6655
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU General Public License as
+ * published by the Free Software Foundation; either version 2 of the
+ * License, or (at your option) any later version.
+ */
+
+#include <linux/module.h>
+#include <net/tcp.h>
+#include <net/inet_common.h>
+#include <linux/highmem.h>
+#include <linux/netdevice.h>
+#include <crypto/aead.h>
+
+#include <net/tls.h>
+
+static int tls_kernel_sendpage(struct sock *sk, int flags);
+
+/* Build the 13-byte TLS 1.2 AAD in @buf: the 8-byte explicit nonce
+ * (@nonce_explicit; callers pass tls_ctx->iv, which tls_release_tx_frag()
+ * advances once per completed record), then the record type, the
+ * protocol version and the big-endian payload length @size.
+ * @sk and @recv are currently unused.
+ */
+static inline void tls_make_aad(struct sock *sk,
+				int recv,
+				char *buf,
+				size_t size,
+				char *nonce_explicit,
+				unsigned char record_type)
+{
+	memcpy(buf, nonce_explicit, TLS_NONCE_SIZE);
+
+	buf[8] = record_type;
+	buf[9] = TLS_1_2_VERSION_MAJOR;
+	buf[10] = TLS_1_2_VERSION_MINOR;
+	buf[11] = size >> 8;
+	buf[12] = size & 0xFF;
+}
+
+/* Encrypt @data_len bytes described by @sgin (AAD | plaintext | tag)
+ * into @sgout and, on success, hand the resulting record to TCP.
+ *
+ * The request is freed as soon as crypto_aead_encrypt() returns, so the
+ * transform must be synchronous; tls_set_sw_offload() allocates it with
+ * CRYPTO_ALG_ASYNC in the mask to guarantee that.  @skb is unused.
+ *
+ * Returns 0 on success or a negative errno from the allocation or the
+ * cipher.  The result of transmitting the record is not propagated:
+ * tls_kernel_sendpage() aborts the connection itself on hard errors.
+ */
+static int tls_do_encryption(struct sock *sk, struct scatterlist *sgin,
+			     struct scatterlist *sgout, size_t data_len,
+			     struct sk_buff *skb)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	unsigned int req_size = sizeof(struct aead_request) +
+				crypto_aead_reqsize(ctx->aead_send);
+	struct aead_request *aead_req;
+	int ret;
+
+	pr_debug("tls_do_encryption %p\n", sk);
+
+	/* Called from sendmsg/sendpage with the socket lock held, so a
+	 * sleeping allocation (like tls_pre_encrypt()'s) is allowed.
+	 */
+	aead_req = kmalloc(req_size, sk->sk_allocation);
+	if (!aead_req)
+		return -ENOMEM;
+
+	aead_request_set_tfm(aead_req, ctx->aead_send);
+	/* kmalloc() leaves base.flags/complete uninitialised; clear them. */
+	aead_request_set_callback(aead_req, 0, NULL, NULL);
+	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_crypt(aead_req, sgin, sgout, data_len, tls_ctx->iv);
+
+	ret = crypto_aead_encrypt(aead_req);
+
+	kfree(aead_req);
+	if (ret < 0)
+		return ret;
+
+	tls_kernel_sendpage(sk, MSG_DONTWAIT);
+
+	return ret;
+}
+
+/* Allocate enough pages to hold the encrypted record (header, explicit
+ * nonce, ciphertext and tag), point ctx->sg_tx_data at them (entry 0 is
+ * the AAD) and remember them in ctx->tx_frag.
+ */
+static int tls_pre_encrypt(struct sock *sk, size_t data_len)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	size_t encrypt_len = data_len + TLS_OVERHEAD;
+	unsigned int npages = DIV_ROUND_UP(encrypt_len, PAGE_SIZE);
+	struct scatterlist *sg;
+	struct page *tx_pages;
+	unsigned int i;
+
+	/* Entry 0 holds the AAD, so at most TLS_SG_DATA_SIZE - 1 pages fit;
+	 * refuse an oversized record rather than overrun sg_tx_data.
+	 */
+	if (WARN_ON(npages > TLS_SG_DATA_SIZE - 1))
+		return -EMSGSIZE;
+
+	ctx->order_npages = order_base_2(npages);
+	WARN_ON(ctx->order_npages < 0 || ctx->order_npages > 3);
+
+	sg_init_table(ctx->sg_tx_data, TLS_SG_DATA_SIZE);
+	sg_set_buf(&ctx->sg_tx_data[0], ctx->aad_send, sizeof(ctx->aad_send));
+
+	tx_pages = alloc_pages(GFP_KERNEL | __GFP_COMP, ctx->order_npages);
+	if (!tx_pages)
+		return -ENOMEM;
+
+	/* The first page leaves room for the prepend (header + explicit
+	 * nonce), which the caller writes there with tls_fill_prepend().
+	 */
+	sg = ctx->sg_tx_data + 1;
+	sg_set_page(sg, tx_pages, PAGE_SIZE - TLS_PREPEND_SIZE,
+		    TLS_PREPEND_SIZE);
+	for (i = 1; i < npages; i++)
+		sg_set_page(sg + i, tx_pages + i, PAGE_SIZE, 0);
+
+	__skb_frag_set_page(&ctx->tx_frag, tx_pages);
+
+	return 0;
+}
+
+/* Called after each successful push and as the TLS context's
+ * sk_write_space hook.  Once the current record has been fully handed
+ * to TCP (no pending open record), drop the plaintext skb at the head
+ * of tx_queue, release the write memory accumulated in ctx->wmem_len,
+ * advance the record sequence number and free the ciphertext pages.
+ * Always chains to the saved sk_write_space callback.
+ */
+static void tls_release_tx_frag(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct page *tx_page = skb_frag_page(&ctx->tx_frag);
+
+	if (!tls_is_pending_open_record(tls_ctx) && tx_page) {
+		/* Successfully sent the whole record, account for it. */
+		struct sk_buff *head = skb_dequeue(&ctx->tx_queue);
+
+		sk->sk_wmem_queued -= ctx->wmem_len;
+		sk_mem_uncharge(sk, ctx->wmem_len);
+		ctx->wmem_len = 0;
+		kfree_skb(head);
+		ctx->unsent -= skb_frag_size(&ctx->tx_frag) - TLS_OVERHEAD;
+		tls_increment_seqno(tls_ctx->iv, sk);
+		__free_pages(tx_page, ctx->order_npages);
+		__skb_frag_set_page(&ctx->tx_frag, NULL);
+	}
+	ctx->sk_write_space(sk);
+}
+
+/* Add the header/nonce/tag overhead to tx_frag's size and push the
+ * encrypted record to TCP.  On success tls_release_tx_frag() frees the
+ * record if it went out completely; any error other than -EAGAIN is
+ * passed on to tls_err_abort().
+ */
+static int tls_kernel_sendpage(struct sock *sk, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	int ret;
+
+	skb_frag_size_add(&ctx->tx_frag, TLS_OVERHEAD);
+	ret = tls_push_frags(sk, tls_ctx, &ctx->tx_frag, 1, 0, flags);
+	if (ret >= 0)
+		tls_release_tx_frag(sk);
+	else if (ret != -EAGAIN)
+		tls_err_abort(sk);
+
+	return ret;
+}
+
+/* Encrypt @bytes of user data described by @sgin (@pages entries,
+ * pinned by zerocopy_from_iter()) into a freshly allocated record and
+ * send it.  The caller drops the page references afterwards.
+ * On failure sk_err is set to EPIPE and a negative errno is returned.
+ */
+static int tls_push_zerocopy(struct sock *sk, struct scatterlist *sgin,
+			     int pages, int bytes, unsigned char record_type)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	int ret;
+
+	tls_make_aad(sk, 0, ctx->aad_send, bytes, tls_ctx->iv, record_type);
+
+	/* AAD -> user pages -> tag; sgin[pages] becomes the chain entry. */
+	sg_chain(ctx->sgaad_send, 2, sgin);
+	sg_chain(sgin, pages + 1, ctx->sgtag_send);
+
+	ret = tls_pre_encrypt(sk, bytes);
+	if (ret < 0)
+		goto err;
+
+	tls_fill_prepend(tls_ctx,
+			 page_address(skb_frag_page(&ctx->tx_frag)),
+			 bytes, record_type);
+
+	skb_frag_size_set(&ctx->tx_frag, bytes);
+
+	ret = tls_do_encryption(sk, ctx->sgaad_send, ctx->sg_tx_data,
+				bytes, NULL);
+	if (ret < 0)
+		goto err;
+
+	return 0;
+
+err:
+	sk->sk_err = EPIPE;
+	return ret;
+}
+
+/* Encrypt up to TLS_MAX_PAYLOAD_SIZE bytes of queued plaintext from the
+ * skb at the head of tx_queue into one record and send it.
+ *
+ * Returns 0 when there was nothing to do or the record was handed off,
+ * otherwise a negative errno (sk_err is then set to EPIPE).
+ */
+static int tls_push(struct sock *sk, unsigned char record_type)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	int bytes = min_t(int, ctx->unsent, (int)TLS_MAX_PAYLOAD_SIZE);
+	int nsg, ret = 0;
+	struct sk_buff *head = skb_peek(&ctx->tx_queue);
+
+	if (!head)
+		return 0;
+
+	/* The previous record is still only partially sent (e.g. TCP
+	 * returned -EAGAIN): its pages are still in tx_frag and the sequence
+	 * number has not been advanced.  Encrypting again now would reuse
+	 * the GCM nonce and leak tx_frag, so leave the data queued until
+	 * tls_release_tx_frag() retires that record.
+	 */
+	if (skb_frag_page(&ctx->tx_frag))
+		return 0;
+
+	bytes = min_t(int, bytes, head->len);
+
+	sg_init_table(ctx->sg_tx_data2, ARRAY_SIZE(ctx->sg_tx_data2));
+	nsg = skb_to_sgvec(head, &ctx->sg_tx_data2[0], 0, bytes);
+
+	/* The length of sg into encryption must not be over
+	 * ALG_MAX_PAGES. The aad takes the first sg, so the payload
+	 * may use at most ALG_MAX_PAGES - 1 entries.
+	 */
+	if (nsg > ALG_MAX_PAGES - 1) {
+		ret = -EBADMSG;
+		goto out;
+	}
+
+	tls_make_aad(sk, 0, ctx->aad_send, bytes, tls_ctx->iv, record_type);
+
+	sg_chain(ctx->sgaad_send, 2, ctx->sg_tx_data2);
+	sg_chain(ctx->sg_tx_data2, nsg + 1, ctx->sgtag_send);
+
+	ret = tls_pre_encrypt(sk, bytes);
+	if (ret < 0)
+		goto out;
+
+	tls_fill_prepend(tls_ctx,
+			 page_address(skb_frag_page(&ctx->tx_frag)),
+			 bytes, record_type);
+
+	skb_frag_size_set(&ctx->tx_frag, bytes);
+	tls_ctx->pending_offset = 0;
+	head->sk = sk;
+
+	ret = tls_do_encryption(sk, ctx->sgaad_send, ctx->sg_tx_data,
+				bytes, head);
+
+out:
+	if (ret < 0) {
+		sk->sk_err = EPIPE;
+		return ret;
+	}
+
+	return 0;
+}
+
+/* Allocate an empty skb and append it to tx_queue.  While allocation
+ * fails, push what is already queued and wait for send-buffer memory.
+ * Returns the new skb, or NULL with *err set to a negative errno.
+ */
+static struct sk_buff *tls_sw_alloc_tx_skb(struct sock *sk,
+					   unsigned char record_type,
+					   long *timeo, int *err)
+{
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_get_ctx(sk));
+	struct sk_buff *skb;
+
+	for (;;) {
+		skb = alloc_skb(0, sk->sk_allocation);
+		if (skb)
+			break;
+
+		*err = tls_push(sk, record_type);
+		if (*err)
+			return NULL;
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+		*err = sk_stream_wait_memory(sk, timeo);
+		if (*err)
+			return NULL;
+	}
+
+	__skb_queue_tail(&ctx->tx_queue, skb);
+	return skb;
+}
+
+/* Pin the user pages backing up to TLS_MAX_PAYLOAD_SIZE bytes of @from,
+ * describe them in @sg and advance @from past them.  If @bytes is
+ * non-NULL it receives the number of bytes pinned.
+ *
+ * Returns the number of @sg entries filled, or -EFAULT after dropping
+ * every page reference taken here.  The caller must size @sg; the only
+ * caller, tls_sw_sendmsg(), uses this when the iterator spans fewer
+ * than ALG_MAX_PAGES pages.
+ */
+static int zerocopy_from_iter(struct iov_iter *from,
+			      struct scatterlist *sg, int *bytes)
+{
+	size_t total = 0;
+	int n = 0;
+	int i;
+
+	if (bytes)
+		*bytes = 0;
+
+	/* Bound the running total, not just each call, so a multi-segment
+	 * iovec can never yield a record larger than TLS_MAX_PAYLOAD_SIZE.
+	 */
+	while (iov_iter_count(from) && n < MAX_SKB_FRAGS - 1 &&
+	       total < TLS_MAX_PAYLOAD_SIZE) {
+		struct page *pages[MAX_SKB_FRAGS];
+		size_t start;
+		ssize_t copied;
+		int j = 0;
+
+		copied = iov_iter_get_pages(from, pages,
+					    TLS_MAX_PAYLOAD_SIZE - total,
+					    MAX_SKB_FRAGS - n, &start);
+		if (copied < 0)
+			goto err;
+
+		iov_iter_advance(from, copied);
+		total += copied;
+
+		while (copied) {
+			int size = min_t(int, copied, PAGE_SIZE - start);
+
+			sg_set_page(&sg[n], pages[j], size, start);
+			start = 0;
+			copied -= size;
+			j++;
+			n++;
+		}
+	}
+
+	if (bytes)
+		*bytes = total;
+	return n;
+
+err:
+	for (i = 0; i < n; i++)
+		put_page(sg_page(&sg[i]));
+	return -EFAULT;
+}
+
+/* sendmsg() for TLS_STATE_SW sockets.
+ *
+ * Data is encrypted straight from the pinned user pages when nothing is
+ * queued, no record is in flight, MSG_MORE is not set and the iterator
+ * spans fewer than ALG_MAX_PAGES pages.  Otherwise it is copied into
+ * skbs on tx_queue and encrypted by tls_push() once a full record has
+ * accumulated or the message ends.  With MSG_OOB the first byte of the
+ * message is taken as the TLS record type.
+ *
+ * Returns @size on success or a negative errno.
+ */
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	int ret = 0;
+	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+	bool eor = !(msg->msg_flags & MSG_MORE);
+	struct sk_buff *skb = NULL;
+	size_t copy;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+
+	lock_sock(sk);
+
+	if (msg->msg_flags & MSG_OOB) {
+		if (!eor || ctx->unsent) {
+			ret = -EINVAL;
+			goto send_end;
+		}
+
+		/* Never return directly here: the socket lock is held. */
+		if (copy_from_iter(&record_type, 1, &msg->msg_iter) != 1) {
+			ret = -EFAULT;
+			goto send_end;
+		}
+	}
+
+	while (msg_data_left(msg)) {
+		bool merge = true;
+		int i;
+		struct page_frag *pfrag;
+
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto send_end;
+		}
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_memory;
+
+		skb = skb_peek_tail(&ctx->tx_queue);
+		/* Try for zerocopy */
+		if (!skb && !skb_frag_page(&ctx->tx_frag) && eor) {
+			struct scatterlist sgin[ALG_MAX_PAGES + 1];
+			int pages;
+			int bytes;
+
+			if (iov_iter_npages(&msg->msg_iter, ALG_MAX_PAGES) >=
+			    ALG_MAX_PAGES)
+				goto reg_send;
+
+			sg_init_table(sgin, ALG_MAX_PAGES + 1);
+			pages = zerocopy_from_iter(&msg->msg_iter, &sgin[0],
+						   &bytes);
+			if (pages < 0) {
+				ret = pages;
+				goto send_end;
+			}
+			ctx->unsent += bytes;
+
+			ret = tls_push_zerocopy(sk, sgin, pages, bytes,
+						record_type);
+			for (; pages > 0; pages--)
+				put_page(sg_page(&sgin[pages - 1]));
+			if (ret < 0) {
+				tls_err_abort(sk);
+				goto send_end;
+			}
+			continue;
+		}
+
+reg_send:
+		if (!skb) {
+			skb = alloc_skb(0, sk->sk_allocation);
+			if (!skb)
+				goto wait_for_memory;
+			__skb_queue_tail(&ctx->tx_queue, skb);
+		}
+
+		i = skb_shinfo(skb)->nr_frags;
+		pfrag = sk_page_frag(sk);
+
+		if (!sk_page_frag_refill(sk, pfrag))
+			goto wait_for_memory;
+
+		if (!skb_can_coalesce(skb, i, pfrag->page,
+				      pfrag->offset)) {
+			if (i == ALG_MAX_PAGES) {
+				struct sk_buff *tskb;
+
+				/* Queue a new tail skb; linking it through
+				 * skb->next would hide it from tx_queue.
+				 */
+				tskb = alloc_skb(0, sk->sk_allocation);
+				if (!tskb)
+					goto wait_for_memory;
+
+				tskb->ip_summed = CHECKSUM_UNNECESSARY;
+				__skb_queue_tail(&ctx->tx_queue, tskb);
+				continue;
+			}
+			merge = false;
+		}
+
+		copy = min_t(int, msg_data_left(msg),
+			     pfrag->size - pfrag->offset);
+		copy = min_t(int, copy, TLS_MAX_PAYLOAD_SIZE - ctx->unsent);
+
+		if (!sk_wmem_schedule(sk, copy))
+			goto wait_for_memory;
+
+		ret = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb,
+					       pfrag->page,
+					       pfrag->offset,
+					       copy);
+		if (ret)
+			goto send_end;
+		/* Only account what was actually copied and charged. */
+		ctx->wmem_len += copy;
+
+		/* Update the skb. */
+		if (merge) {
+			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
+		} else {
+			skb_fill_page_desc(skb, i, pfrag->page,
+					   pfrag->offset, copy);
+			get_page(pfrag->page);
+		}
+
+		pfrag->offset += copy;
+		ctx->unsent += copy;
+
+		if (ctx->unsent >= TLS_MAX_PAYLOAD_SIZE) {
+			ret = tls_push(sk, record_type);
+			if (ret)
+				goto send_end;
+		}
+
+		continue;
+
+wait_for_memory:
+		ret = tls_push(sk, record_type);
+		if (ret)
+			goto send_end;
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret)
+			goto send_end;
+	}
+
+	if (eor)
+		ret = tls_push(sk, record_type);
+
+send_end:
+	ret = sk_stream_error(sk, msg->msg_flags, ret);
+
+	/* make sure we wake any epoll edge trigger waiter */
+	if (unlikely(skb_queue_len(&ctx->tx_queue) == 0 && ret == -EAGAIN))
+		sk->sk_write_space(sk);
+
+	release_sock(sk);
+	return ret < 0 ? ret : size;
+}
+
+/* sk_destruct hook for TLS_STATE_SW sockets: free the AEAD transform,
+ * any ciphertext pages still held in tx_frag and the queued plaintext,
+ * then run the common TLS destructor.
+ */
+static void tls_sw_sk_destruct(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct page *tx_page = skb_frag_page(&ctx->tx_frag);
+
+	crypto_free_aead(ctx->aead_send);
+
+	if (tx_page)
+		__free_pages(tx_page, ctx->order_npages);
+
+	skb_queue_purge(&ctx->tx_queue);
+	tls_sk_destruct(sk, tls_ctx);
+}
+
+/* Set up software TLS transmit for @sk using the AES-GCM-128 key
+ * material in @ctx->crypto_send.  The per-socket tls_sw_context and
+ * AEAD transform are fully set up before they are published through
+ * @ctx->priv_ctx and the sk_destruct/sk_write_space hooks are
+ * installed; on error everything allocated here is freed again and
+ * @ctx->priv_ctx and @ctx->iv are left NULL.
+ *
+ * Returns 0 or a negative errno.
+ */
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
+{
+	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE +
+		    TLS_CIPHER_AES_GCM_128_SALT_SIZE];
+	struct tls_crypto_info *crypto_info;
+	struct tls_crypto_info_aes_gcm_128 *gcm_128_info;
+	struct tls_sw_context *sw_ctx;
+	u16 nonce_size, tag_size, iv_size;
+	char *iv;
+	int rc = 0;
+
+	if (!ctx) {
+		rc = -EINVAL;
+		goto out;
+	}
+
+	if (ctx->priv_ctx) {
+		rc = -EEXIST;
+		goto out;
+	}
+
+	crypto_info = &ctx->crypto_send;
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128: {
+		gcm_128_info =
+			(struct tls_crypto_info_aes_gcm_128 *)crypto_info;
+		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		iv = gcm_128_info->iv;
+		break;
+	}
+	default:
+		rc = -EINVAL;
+		goto out;
+	}
+
+	sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
+	if (!sw_ctx) {
+		rc = -ENOMEM;
+		goto out;
+	}
+
+	ctx->iv = kmalloc(iv_size, GFP_KERNEL);
+	if (!ctx->iv) {
+		rc = -ENOMEM;
+		goto free_sw_ctx;
+	}
+	memcpy(ctx->iv, iv, iv_size);
+
+	/* tls_do_encryption() frees each request as soon as
+	 * crypto_aead_encrypt() returns, so only a synchronous
+	 * implementation may be used: hence CRYPTO_ALG_ASYNC in the mask.
+	 * NOTE(review): CRYPTO_ALG_INTERNAL in the type also admits
+	 * internal-only implementations; confirm that is intended.
+	 */
+	sw_ctx->aead_send = crypto_alloc_aead("rfc5288(gcm(aes))",
+					      CRYPTO_ALG_INTERNAL,
+					      CRYPTO_ALG_ASYNC);
+	if (IS_ERR(sw_ctx->aead_send)) {
+		rc = PTR_ERR(sw_ctx->aead_send);
+		sw_ctx->aead_send = NULL;
+		pr_err("tls_sw: failed to allocate AEAD transform: %d\n", rc);
+		goto free_iv;
+	}
+
+	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	memcpy(keyval + TLS_CIPHER_AES_GCM_128_KEY_SIZE, gcm_128_info->salt,
+	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+
+	rc = crypto_aead_setkey(sw_ctx->aead_send, keyval, sizeof(keyval));
+	/* Do not leave key material behind on the stack. */
+	memzero_explicit(keyval, sizeof(keyval));
+	if (rc)
+		goto free_aead;
+
+	rc = crypto_aead_setauthsize(sw_ctx->aead_send, TLS_TAG_SIZE);
+	if (rc)
+		goto free_aead;
+
+	/* Preallocation for sending
+	 * scatterlist: AAD | data | TAG (for crypto API)
+	 * vec: HEADER | data | TAG
+	 */
+	sg_init_table(sw_ctx->sg_tx_data, TLS_SG_DATA_SIZE);
+	sg_set_buf(&sw_ctx->sg_tx_data[0], sw_ctx->aad_send,
+		   sizeof(sw_ctx->aad_send));
+
+	sg_set_buf(sw_ctx->sg_tx_data + TLS_SG_DATA_SIZE - 2,
+		   sw_ctx->tag_send, sizeof(sw_ctx->tag_send));
+	sg_mark_end(sw_ctx->sg_tx_data + TLS_SG_DATA_SIZE - 1);
+
+	sg_init_table(sw_ctx->sgaad_send, 2);
+	sg_init_table(sw_ctx->sgtag_send, 2);
+
+	sg_set_buf(&sw_ctx->sgaad_send[0], sw_ctx->aad_send,
+		   sizeof(sw_ctx->aad_send));
+	/* chaining to tag is performed on actual data size when sending */
+	sg_set_buf(&sw_ctx->sgtag_send[0], sw_ctx->tag_send,
+		   sizeof(sw_ctx->tag_send));
+
+	sg_unmark_end(&sw_ctx->sgaad_send[1]);
+
+	skb_queue_head_init(&sw_ctx->tx_queue);
+	sw_ctx->sk = sk;
+
+	ctx->prepand_size = TLS_HEADER_SIZE + nonce_size;
+	ctx->tag_size = tag_size;
+	ctx->iv_size = iv_size;
+
+	/* Publish the context and install the hooks only now that nothing
+	 * else can fail.
+	 */
+	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
+	sw_ctx->sk_write_space = ctx->sk_write_space;
+	ctx->sk_write_space = tls_release_tx_frag;
+	sk->sk_destruct = tls_sw_sk_destruct;
+
+	return 0;
+
+free_aead:
+	crypto_free_aead(sw_ctx->aead_send);
+free_iv:
+	kfree(ctx->iv);
+	ctx->iv = NULL;
+free_sw_ctx:
+	kfree(sw_ctx);
+out:
+	return rc;
+}
+
+/* sendpage() for TLS_STATE_SW sockets: reference @page (coalescing with
+ * the previous fragment where possible) from skbs on tx_queue in chunks
+ * of at most TLS_MAX_PAYLOAD_SIZE, pushing a record whenever one fills
+ * up and, without MSG_MORE/MSG_SENDPAGE_NOTLAST, after every chunk.
+ *
+ * Returns the number of bytes queued or a negative errno.
+ */
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+		    int offset, size_t size, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	int ret = 0, i;
+	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	bool eor;
+	struct sk_buff *skb = NULL;
+	size_t queued = 0;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+
+	if (flags & MSG_SENDPAGE_NOTLAST)
+		flags |= MSG_MORE;
+
+	/* No MSG_EOR from splice, only look at MSG_MORE */
+	eor = !(flags & MSG_MORE);
+
+	lock_sock(sk);
+
+	if (flags & MSG_OOB) {
+		ret = -ENOTSUPP;
+		goto sendpage_end;
+	}
+	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+	/* Call the sk_stream functions to manage the sndbuf mem. */
+	while (size > 0) {
+		size_t send_size = min_t(size_t, size, TLS_MAX_PAYLOAD_SIZE);
+
+		if (!sk_stream_memory_free(sk) ||
+		    (ctx->unsent + send_size > TLS_MAX_PAYLOAD_SIZE)) {
+			ret = tls_push(sk, record_type);
+			if (ret)
+				goto sendpage_end;
+			set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+			ret = sk_stream_wait_memory(sk, &timeo);
+			if (ret)
+				goto sendpage_end;
+		}
+
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto sendpage_end;
+		}
+
+		skb = skb_peek_tail(&ctx->tx_queue);
+		i = skb ? skb_shinfo(skb)->nr_frags : 0;
+
+		if (skb && skb_can_coalesce(skb, i, page, offset)) {
+			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1],
+					  send_size);
+			skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+			goto coalesced;
+		}
+
+		/* Need a fresh skb: none is queued or the tail one is full.
+		 * It must go on tx_queue itself (not skb->next) so that
+		 * tls_push() and tls_release_tx_frag() can find it.
+		 */
+		if (!skb || i >= ALG_MAX_PAGES) {
+			skb = tls_sw_alloc_tx_skb(sk, record_type, &timeo,
+						  &ret);
+			if (!skb)
+				goto sendpage_end;
+			i = 0;
+		}
+
+		get_page(page);
+		skb_fill_page_desc(skb, i, page, offset, send_size);
+		skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+
+coalesced:
+		skb->len += send_size;
+		skb->data_len += send_size;
+		skb->truesize += send_size;
+		sk->sk_wmem_queued += send_size;
+		ctx->wmem_len += send_size;
+		sk_mem_charge(sk, send_size);
+		ctx->unsent += send_size;
+		queued += send_size;
+		/* Advance by this chunk only, not by the running total. */
+		offset += send_size;
+		size -= send_size;
+
+		if (eor || ctx->unsent >= TLS_MAX_PAYLOAD_SIZE) {
+			ret = tls_push(sk, record_type);
+			if (ret)
+				goto sendpage_end;
+		}
+	}
+
+	if (eor || ctx->unsent >= TLS_MAX_PAYLOAD_SIZE)
+		ret = tls_push(sk, record_type);
+
+sendpage_end:
+	if (ret < 0)
+		ret = sk_stream_error(sk, flags, ret);
+
+	release_sock(sk);
+
+	return ret < 0 ? ret : queued;
+}
-- 
2.7.4