From: Ard Biesheuvel <ardb@xxxxxxxxxx> In preparation for optimizing the CCM core asm code using permutation vectors and overlapping loads and stores, ensure that inputs shorter than the size of an AES block are passed via a buffer on the stack, in a way that positions the data at the end of a 16-byte buffer. This removes the need for the asm code to reason about a rare corner case where the tail of the data cannot be read/written using a single NEON load/store instruction. While at it, tweak the copyright header and authorship to bring it up to date. Signed-off-by: Ard Biesheuvel <ardb@xxxxxxxxxx> --- arch/arm64/crypto/aes-ce-ccm-glue.c | 57 ++++++++++++++------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c index b177ebea7d09..2f4e6a318fcd 100644 --- a/arch/arm64/crypto/aes-ce-ccm-glue.c +++ b/arch/arm64/crypto/aes-ce-ccm-glue.c @@ -1,8 +1,11 @@ // SPDX-License-Identifier: GPL-2.0-only /* - * aes-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions + * aes-ce-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions * - * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@xxxxxxxxxx> + * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC + * + * Author: Ard Biesheuvel <ardb@xxxxxxxxxx> */ #include <asm/neon.h> @@ -149,7 +152,7 @@ static int ccm_encrypt(struct aead_request *req) struct crypto_aes_ctx *ctx = crypto_aead_ctx(aead); struct skcipher_walk walk; u8 __aligned(8) mac[AES_BLOCK_SIZE]; - u8 buf[AES_BLOCK_SIZE]; + u8 orig_iv[AES_BLOCK_SIZE]; u32 len = req->cryptlen; int err; @@ -158,7 +161,7 @@ static int ccm_encrypt(struct aead_request *req) return err; /* preserve the original iv for the final round */ - memcpy(buf, req->iv, AES_BLOCK_SIZE); + memcpy(orig_iv, req->iv, AES_BLOCK_SIZE); err = skcipher_walk_aead_encrypt(&walk, req, false); if (unlikely(err)) @@ -171,16 +174,26 @@ static int ccm_encrypt(struct aead_request *req) do { u32 tail = walk.nbytes % AES_BLOCK_SIZE; + const u8 *src = walk.src.virt.addr; + u8 *dst = walk.dst.virt.addr; + u8 buf[AES_BLOCK_SIZE]; if (walk.nbytes == walk.total) tail = 0; - ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr, - walk.nbytes - tail, ctx->key_enc, - num_rounds(ctx), mac, walk.iv); + if (unlikely(walk.total < AES_BLOCK_SIZE)) + src = dst = memcpy(buf + sizeof(buf) - walk.total, + src, walk.total); + + ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail, + ctx->key_enc, num_rounds(ctx), + mac, walk.iv); + + if (unlikely(walk.total < AES_BLOCK_SIZE)) + memcpy(walk.dst.virt.addr, dst, walk.total); if (walk.nbytes == walk.total) - ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx)); + ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx)); if (walk.nbytes) { err = skcipher_walk_done(&walk, tail); @@ -206,7 +219,7 @@ static int ccm_decrypt(struct aead_request *req) unsigned int authsize = crypto_aead_authsize(aead); struct skcipher_walk walk; u8 __aligned(8) mac[AES_BLOCK_SIZE]; - u8 buf[AES_BLOCK_SIZE]; + u8 orig_iv[AES_BLOCK_SIZE]; u32 len = req->cryptlen - authsize; int err; @@ -215,7 +228,7 @@ static int ccm_decrypt(struct aead_request *req) return err; /* preserve the original iv for the final 
round */ - memcpy(buf, req->iv, AES_BLOCK_SIZE); + memcpy(orig_iv, req->iv, AES_BLOCK_SIZE); err = skcipher_walk_aead_decrypt(&walk, req, false); if (unlikely(err)) @@ -228,16 +241,26 @@ static int ccm_decrypt(struct aead_request *req) do { u32 tail = walk.nbytes % AES_BLOCK_SIZE; + const u8 *src = walk.src.virt.addr; + u8 *dst = walk.dst.virt.addr; + u8 buf[AES_BLOCK_SIZE]; if (walk.nbytes == walk.total) tail = 0; - ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr, - walk.nbytes - tail, ctx->key_enc, - num_rounds(ctx), mac, walk.iv); + if (unlikely(walk.total < AES_BLOCK_SIZE)) + src = dst = memcpy(buf + sizeof(buf) - walk.total, + src, walk.total); + + ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail, + ctx->key_enc, num_rounds(ctx), + mac, walk.iv); + + if (unlikely(walk.total < AES_BLOCK_SIZE)) + memcpy(walk.dst.virt.addr, dst, walk.total); if (walk.nbytes == walk.total) - ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx)); + ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx)); if (walk.nbytes) { err = skcipher_walk_done(&walk, tail); @@ -250,11 +273,11 @@ static int ccm_decrypt(struct aead_request *req) return err; /* compare calculated auth tag with the stored one */ - scatterwalk_map_and_copy(buf, req->src, + scatterwalk_map_and_copy(orig_iv, req->src, req->assoclen + req->cryptlen - authsize, authsize, 0); - if (crypto_memneq(mac, buf, authsize)) + if (crypto_memneq(mac, orig_iv, authsize)) return -EBADMSG; return 0; } @@ -293,6 +316,6 @@ module_init(aes_mod_init); module_exit(aes_mod_exit); MODULE_DESCRIPTION("Synchronous AES in CCM mode using ARMv8 Crypto Extensions"); -MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@xxxxxxxxxx>"); +MODULE_AUTHOR("Ard Biesheuvel <ardb@xxxxxxxxxx>"); MODULE_LICENSE("GPL v2"); MODULE_ALIAS_CRYPTO("ccm(aes)"); -- 2.43.0.275.g3460e3d667-goog