On Nov 28, 2023, at 12:07, Eric Biggers <ebiggers@xxxxxxxxxx> wrote:
> On Mon, Nov 27, 2023 at 03:06:57PM +0800, Jerry Shih wrote:
>> +typedef void (*aes_xts_func)(const u8 *in, u8 *out, size_t length,
>> +			       const struct crypto_aes_ctx *key, u8 *iv,
>> +			       int update_iv);
>
> There's no need for this indirection, because the function pointer can only have
> one value.
>
> Note also that when Control Flow Integrity is enabled, assembly functions can
> only be called indirectly when they use SYM_TYPED_FUNC_START. That's another
> reason to avoid indirect calls that aren't actually necessary.

We have two function pointers, one for encryption and one for decryption:

static int xts_encrypt(struct skcipher_request *req)
{
	return xts_crypt(req, rv64i_zvbb_zvkg_zvkned_aes_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return xts_crypt(req, rv64i_zvbb_zvkg_zvkned_aes_xts_decrypt);
}

The encryption and decryption paths could be folded together into `xts_crypt()`,
but we would then need an additional branch to select the encryption or
decryption routine if we want to avoid the indirect calls. Using
`SYM_TYPED_FUNC_START` in the asm might be better.
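
For reference, a rough (untested) sketch of the branch-based alternative: it
assumes `xts_crypt()` is changed to take a bool selector instead of a function
pointer, reuses the asm symbol names and the `aes_xts_func` argument list from
this patch, and the `aes_xts_call()` helper name is made up for illustration.

	/*
	 * Hypothetical helper: pick the asm routine with direct calls instead
	 * of passing a function pointer down, so SYM_TYPED_FUNC_START is not
	 * needed for CFI.
	 */
	static inline void aes_xts_call(bool enc, const u8 *in, u8 *out,
					size_t length,
					const struct crypto_aes_ctx *key,
					u8 *iv, int update_iv)
	{
		if (enc)
			rv64i_zvbb_zvkg_zvkned_aes_xts_encrypt(in, out, length,
							       key, iv,
							       update_iv);
		else
			rv64i_zvbb_zvkg_zvkned_aes_xts_decrypt(in, out, length,
							       key, iv,
							       update_iv);
	}

	static int xts_encrypt(struct skcipher_request *req)
	{
		return xts_crypt(req, true);
	}

	static int xts_decrypt(struct skcipher_request *req)
	{
		return xts_crypt(req, false);
	}

Inside `xts_crypt()`, the two `func(...)` call sites would then become
`aes_xts_call(enc, ...)`. All calls into the asm stay direct, at the cost of
the extra branch mentioned above.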

>> +	nbytes &= (~(AES_BLOCK_SIZE - 1));
>
> Expressions like ~(n - 1) should not have another set of parentheses around them

Fixed.

>> +static int xts_crypt(struct skcipher_request *req, aes_xts_func func)
>> +{
>> +	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
>> +	const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
>> +	struct skcipher_request sub_req;
>> +	struct scatterlist sg_src[2], sg_dst[2];
>> +	struct scatterlist *src, *dst;
>> +	struct skcipher_walk walk;
>> +	unsigned int walk_size = crypto_skcipher_walksize(tfm);
>> +	unsigned int tail_bytes;
>> +	unsigned int head_bytes;
>> +	unsigned int nbytes;
>> +	unsigned int update_iv = 1;
>> +	int err;
>> +
>> +	/* xts input size should be bigger than AES_BLOCK_SIZE */
>> +	if (req->cryptlen < AES_BLOCK_SIZE)
>> +		return -EINVAL;
>> +
>> +	/*
>> +	 * We split xts-aes cryption into `head` and `tail` parts.
>> +	 * The head block contains the input from the beginning which doesn't need
>> +	 * `ciphertext stealing` method.
>> +	 * The tail block contains at least two AES blocks including ciphertext
>> +	 * stealing data from the end.
>> +	 */
>> +	if (req->cryptlen <= walk_size) {
>> +		/*
>> +		 * All data is in one `walk`. We could handle it within one AES-XTS call in
>> +		 * the end.
>> +		 */
>> +		tail_bytes = req->cryptlen;
>> +		head_bytes = 0;
>> +	} else {
>> +		if (req->cryptlen & (AES_BLOCK_SIZE - 1)) {
>> +			/*
>> +			 * with ciphertext stealing
>> +			 *
>> +			 * Find the largest tail size which is small than `walk` size while the
>> +			 * head part still fits AES block boundary.
>> +			 */
>> +			tail_bytes = req->cryptlen & (AES_BLOCK_SIZE - 1);
>> +			tail_bytes = walk_size + tail_bytes - AES_BLOCK_SIZE;
>> +			head_bytes = req->cryptlen - tail_bytes;
>> +		} else {
>> +			/* no ciphertext stealing */
>> +			tail_bytes = 0;
>> +			head_bytes = req->cryptlen;
>> +		}
>> +	}
>> +
>> +	riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
>> +
>> +	if (head_bytes && tail_bytes) {
>> +		/* If we have to parts, setup new request for head part only. */
>> +		skcipher_request_set_tfm(&sub_req, tfm);
>> +		skcipher_request_set_callback(
>> +			&sub_req, skcipher_request_flags(req), NULL, NULL);
>> +		skcipher_request_set_crypt(&sub_req, req->src, req->dst,
>> +					   head_bytes, req->iv);
>> +		req = &sub_req;
>> +	}
>> +
>> +	if (head_bytes) {
>> +		err = skcipher_walk_virt(&walk, req, false);
>> +		while ((nbytes = walk.nbytes)) {
>> +			if (nbytes == walk.total)
>> +				update_iv = (tail_bytes > 0);
>> +
>> +			nbytes &= (~(AES_BLOCK_SIZE - 1));
>> +			kernel_vector_begin();
>> +			func(walk.src.virt.addr, walk.dst.virt.addr, nbytes,
>> +			     &ctx->ctx1, req->iv, update_iv);
>> +			kernel_vector_end();
>> +
>> +			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
>> +		}
>> +		if (err || !tail_bytes)
>> +			return err;
>> +
>> +		/*
>> +		 * Setup new request for tail part.
>> +		 * We use `scatterwalk_next()` to find the next scatterlist from last
>> +		 * walk instead of iterating from the beginning.
>> +		 */
>> +		dst = src = scatterwalk_next(sg_src, &walk.in);
>> +		if (req->dst != req->src)
>> +			dst = scatterwalk_next(sg_dst, &walk.out);
>> +		skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);
>> +	}
>> +
>> +	/* tail */
>> +	err = skcipher_walk_virt(&walk, req, false);
>> +	if (err)
>> +		return err;
>> +	if (walk.nbytes != tail_bytes)
>> +		return -EINVAL;
>> +	kernel_vector_begin();
>> +	func(walk.src.virt.addr, walk.dst.virt.addr, walk.nbytes, &ctx->ctx1,
>> +	     req->iv, 0);
>> +	kernel_vector_end();
>> +
>> +	return skcipher_walk_done(&walk, 0);
>> +}
>
> Did you consider writing xts_crypt() the way that arm64 and x86 do it? The
> above seems to reinvent sort of the same thing from first principles. I'm
> wondering if you should just copy the existing approach for now. Then there
> would be no need to add the scatterwalk_next() function, and also the handling
> of inputs that don't need ciphertext stealing would be a bit more streamlined.

I will check the arm64 and x86 implementations. Note that the
`scatterwalk_next()` proposed in this series does the same job as the
`scatterwalk_ffwd()` call used in the arm64 and x86 implementations. The
difference is that `scatterwalk_ffwd()` iterates from the beginning of the
scatterlist (O(n)), while `scatterwalk_next()` just continues from the end
point of the last used scatterlist (O(1)).

>> +static int __init riscv64_aes_block_mod_init(void)
>> +{
>> +	int ret = -ENODEV;
>> +
>> +	if (riscv_isa_extension_available(NULL, ZVKNED) &&
>> +	    riscv_vector_vlen() >= 128 && riscv_vector_vlen() <= 2048) {
>> +		ret = simd_register_skciphers_compat(
>> +			riscv64_aes_algs_zvkned,
>> +			ARRAY_SIZE(riscv64_aes_algs_zvkned),
>> +			riscv64_aes_simd_algs_zvkned);
>> +		if (ret)
>> +			return ret;
>> +
>> +		if (riscv_isa_extension_available(NULL, ZVBB)) {
>> +			ret = simd_register_skciphers_compat(
>> +				riscv64_aes_alg_zvkned_zvkb,
>> +				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
>> +				riscv64_aes_simd_alg_zvkned_zvkb);
>> +			if (ret)
>> +				goto unregister_zvkned;
>
> This makes the registration of the zvkned-zvkb algorithm conditional on zvbb,
> not zvkb. Shouldn't the extension checks actually look like:
>
>	ZVKNED
>		ZVKB
>		ZVBB && ZVKG

Fixed. But then we will have conditions like:

	if (ZVKNED) {
		reg_cipher_1();
		if (ZVKB) {
			reg_cipher_2();
		}
		if (ZVBB && ZVKG) {
			reg_cipher_3();
		}
	}

> - Eric
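
For completeness, a rough sketch of how the corrected extension checks could
look in the init function, reusing the registration calls quoted above and
assuming the usual `simd_unregister_skciphers()` counterpart for cleanup. The
names of the third algorithm arrays (`riscv64_aes_alg_zvkned_zvbb_zvkg`,
`riscv64_aes_simd_alg_zvkned_zvbb_zvkg`) and the `unregister_zvkned_zvkb`
label are made up for illustration; this is untested.

	static int __init riscv64_aes_block_mod_init(void)
	{
		int ret = -ENODEV;

		if (riscv_isa_extension_available(NULL, ZVKNED) &&
		    riscv_vector_vlen() >= 128 && riscv_vector_vlen() <= 2048) {
			ret = simd_register_skciphers_compat(
				riscv64_aes_algs_zvkned,
				ARRAY_SIZE(riscv64_aes_algs_zvkned),
				riscv64_aes_simd_algs_zvkned);
			if (ret)
				return ret;

			/* The zvkned-zvkb algorithms only need Zvkb. */
			if (riscv_isa_extension_available(NULL, ZVKB)) {
				ret = simd_register_skciphers_compat(
					riscv64_aes_alg_zvkned_zvkb,
					ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
					riscv64_aes_simd_alg_zvkned_zvkb);
				if (ret)
					goto unregister_zvkned;
			}

			/* xts(aes) needs both Zvbb and Zvkg on top of Zvkned. */
			if (riscv_isa_extension_available(NULL, ZVBB) &&
			    riscv_isa_extension_available(NULL, ZVKG)) {
				ret = simd_register_skciphers_compat(
					riscv64_aes_alg_zvkned_zvbb_zvkg,
					ARRAY_SIZE(riscv64_aes_alg_zvkned_zvbb_zvkg),
					riscv64_aes_simd_alg_zvkned_zvbb_zvkg);
				if (ret)
					goto unregister_zvkned_zvkb;
			}
		}

		return ret;

	unregister_zvkned_zvkb:
		/* Only registered when Zvkb was available. */
		if (riscv_isa_extension_available(NULL, ZVKB))
			simd_unregister_skciphers(
				riscv64_aes_alg_zvkned_zvkb,
				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
				riscv64_aes_simd_alg_zvkned_zvkb);
	unregister_zvkned:
		simd_unregister_skciphers(riscv64_aes_algs_zvkned,
					  ARRAY_SIZE(riscv64_aes_algs_zvkned),
					  riscv64_aes_simd_algs_zvkned);
		return ret;
	}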