To improve performance on cores with deep pipelines such as ThunderX2, reimplement gcm(aes) using a 4-way interleave rather than the 2-way interleave we use currently. This comes down to a complete rewrite of the GCM part of the combined GCM/GHASH driver. Instead of interleaving two invocations of AES with the GHASH handling at the instruction level, the new version uses a coarser-grained approach where each chunk of 64 bytes is encrypted first and then ghashed (or ghashed and then decrypted in the converse case). The core NEON routine is now able to consume inputs of any size: tail blocks of less than 64 bytes are handled using overlapping loads and stores, and are processed by the same 4-way encryption and hashing routines. This gets rid of most of the branches, and avoids having to return to the C code to handle the tail block using a stack buffer. The table below compares the performance of the old driver and the new one on various micro-architectures and in various modes.
| AES-128 | AES-192 | AES-256 | #bytes | 512 | 1500 | 4k | 512 | 1500 | 4k | 512 | 1500 | 4k | -------+-----+------+-----+-----+------+-----+-----+------+-----+ TX2 | 35% | 23% | 11% | 34% | 20% | 9% | 38% | 25% | 16% | EMAG | 11% | 6% | 3% | 12% | 4% | 2% | 11% | 4% | 2% | A72 | 8% | 5% | -4% | 9% | 4% | -5% | 7% | 4% | -5% | A53 | 11% | 6% | -1% | 10% | 8% | -1% | 10% | 8% | -2% | Signed-off-by: Ard Biesheuvel <ard.biesheuvel@xxxxxxxxxx> --- arch/arm64/crypto/ghash-ce-core.S | 501 ++++++++++++++------ arch/arm64/crypto/ghash-ce-glue.c | 293 +++++------- 2 files changed, 467 insertions(+), 327 deletions(-) diff --git a/arch/arm64/crypto/ghash-ce-core.S b/arch/arm64/crypto/ghash-ce-core.S index 410e8afcf5a7..a791c4adf8e6 100644 --- a/arch/arm64/crypto/ghash-ce-core.S +++ b/arch/arm64/crypto/ghash-ce-core.S @@ -13,8 +13,8 @@ T1 .req v2 T2 .req v3 MASK .req v4 - XL .req v5 - XM .req v6 + XM .req v5 + XL .req v6 XH .req v7 IN1 .req v7 @@ -358,20 +358,37 @@ ENTRY(pmull_ghash_update_p8) __pmull_ghash p8 ENDPROC(pmull_ghash_update_p8) - KS0 .req v12 - KS1 .req v13 - INP0 .req v14 - INP1 .req v15 - - .macro load_round_keys, rounds, rk - cmp \rounds, #12 - blo 2222f /* 128 bits */ - beq 1111f /* 192 bits */ - ld1 {v17.4s-v18.4s}, [\rk], #32 -1111: ld1 {v19.4s-v20.4s}, [\rk], #32 -2222: ld1 {v21.4s-v24.4s}, [\rk], #64 - ld1 {v25.4s-v28.4s}, [\rk], #64 - ld1 {v29.4s-v31.4s}, [\rk] + KS0 .req v8 + KS1 .req v9 + KS2 .req v10 + KS3 .req v11 + + INP0 .req v21 + INP1 .req v22 + INP2 .req v23 + INP3 .req v24 + + K0 .req v25 + K1 .req v26 + K2 .req v27 + K3 .req v28 + K4 .req v12 + K5 .req v13 + K6 .req v4 + K7 .req v5 + K8 .req v14 + K9 .req v15 + KK .req v29 + KL .req v30 + KM .req v31 + + .macro load_round_keys, rounds, rk, tmp + add \tmp, \rk, #64 + ld1 {K0.4s-K3.4s}, [\rk] + ld1 {K4.4s-K5.4s}, [\tmp] + add \tmp, \rk, \rounds, lsl #4 + sub \tmp, \tmp, #32 + ld1 {KK.4s-KM.4s}, [\tmp] .endm .macro enc_round, state, key @@ -379,197 +396,367 @@ ENDPROC(pmull_ghash_update_p8) aesmc 
\state\().16b, \state\().16b .endm - .macro enc_block, state, rounds - cmp \rounds, #12 - b.lo 2222f /* 128 bits */ - b.eq 1111f /* 192 bits */ - enc_round \state, v17 - enc_round \state, v18 -1111: enc_round \state, v19 - enc_round \state, v20 -2222: .irp key, v21, v22, v23, v24, v25, v26, v27, v28, v29 + .macro enc_qround, s0, s1, s2, s3, key + enc_round \s0, \key + enc_round \s1, \key + enc_round \s2, \key + enc_round \s3, \key + .endm + + .macro enc_block, state, rounds, rk, tmp + add \tmp, \rk, #96 + ld1 {K6.4s-K7.4s}, [\tmp], #32 + .irp key, K0, K1, K2, K3, K4 K5 enc_round \state, \key .endr - aese \state\().16b, v30.16b - eor \state\().16b, \state\().16b, v31.16b + + tbnz \rounds, #2, .Lnot128_\@ +.Lout256_\@: + enc_round \state, K6 + enc_round \state, K7 + +.Lout192_\@: + enc_round \state, KK + aese \state\().16b, KL.16b + eor \state\().16b, \state\().16b, KM.16b + + .subsection 1 +.Lnot128_\@: + ld1 {K8.4s-K9.4s}, [\tmp], #32 + enc_round \state, K6 + enc_round \state, K7 + ld1 {K6.4s-K7.4s}, [\tmp] + enc_round \state, K8 + enc_round \state, K9 + tbz \rounds, #1, .Lout192_\@ + b .Lout256_\@ + .previous .endm + .align 6 .macro pmull_gcm_do_crypt, enc - ld1 {SHASH.2d}, [x4], #16 - ld1 {HH.2d}, [x4] - ld1 {XL.2d}, [x1] - ldr x8, [x5, #8] // load lower counter + stp x29, x30, [sp, #-32]! + mov x29, sp + str x19, [sp, #24] + + load_round_keys x7, x6, x8 + + ld1 {SHASH.2d}, [x3], #16 + ld1 {HH.2d-HH4.2d}, [x3] - movi MASK.16b, #0xe1 trn1 SHASH2.2d, SHASH.2d, HH.2d trn2 T1.2d, SHASH.2d, HH.2d -CPU_LE( rev x8, x8 ) - shl MASK.2d, MASK.2d, #57 eor SHASH2.16b, SHASH2.16b, T1.16b - .if \enc == 1 - ldr x10, [sp] - ld1 {KS0.16b-KS1.16b}, [x10] - .endif + trn1 HH34.2d, HH3.2d, HH4.2d + trn2 T1.2d, HH3.2d, HH4.2d + eor HH34.16b, HH34.16b, T1.16b - cbnz x6, 4f + ld1 {XL.2d}, [x4] -0: ld1 {INP0.16b-INP1.16b}, [x3], #32 + cbz x0, 3f // tag only? 
- rev x9, x8 - add x11, x8, #1 - add x8, x8, #2 + ldr w8, [x5, #12] // load lower counter +CPU_LE( rev w8, w8 ) - .if \enc == 1 - eor INP0.16b, INP0.16b, KS0.16b // encrypt input - eor INP1.16b, INP1.16b, KS1.16b +0: mov w9, #4 // max blocks per round + add x10, x0, #0xf + lsr x10, x10, #4 // remaining blocks + + subs x0, x0, #64 + csel w9, w10, w9, mi + add w8, w8, w9 + + bmi 1f + ld1 {INP0.16b-INP3.16b}, [x2], #64 + .subsection 1 + /* + * Populate the four input registers right to left with up to 63 bytes + * of data, using overlapping loads to avoid branches. + * + * INP0 INP1 INP2 INP3 + * 1 byte | | | |x | + * 16 bytes | | | |xxxxxxxx| + * 17 bytes | | |xxxxxxxx|x | + * 47 bytes | |xxxxxxxx|xxxxxxxx|xxxxxxx | + * etc etc + * + * Note that this code may read up to 15 bytes before the start of + * the input. It is up to the calling code to ensure this is safe if + * this happens in the first iteration of the loop (i.e., when the + * input size is < 16 bytes) + */ +1: mov x15, #16 + ands x19, x0, #0xf + csel x19, x19, x15, ne + adr_l x17, .Lpermute_table + 16 + + sub x11, x15, x19 + add x12, x17, x11 + sub x17, x17, x11 + ld1 {T1.16b}, [x12] + sub x10, x1, x11 + sub x11, x2, x11 + + cmp x0, #-16 + csel x14, x15, xzr, gt + cmp x0, #-32 + csel x15, x15, xzr, gt + cmp x0, #-48 + csel x16, x19, xzr, gt + csel x1, x1, x10, gt + csel x2, x2, x11, gt + + ld1 {INP0.16b}, [x2], x14 + ld1 {INP1.16b}, [x2], x15 + ld1 {INP2.16b}, [x2], x16 + ld1 {INP3.16b}, [x2] + tbl INP3.16b, {INP3.16b}, T1.16b + b 2f + .previous + +2: .if \enc == 0 + bl pmull_gcm_ghash_4x .endif - ld1 {KS0.8b}, [x5] // load upper counter - rev x11, x11 - sub w0, w0, #2 - mov KS1.8b, KS0.8b - ins KS0.d[1], x9 // set lower counter - ins KS1.d[1], x11 + bl pmull_gcm_enc_4x - rev64 T1.16b, INP1.16b + tbnz x0, #63, 6f + st1 {INP0.16b-INP3.16b}, [x1], #64 + .if \enc == 1 + bl pmull_gcm_ghash_4x + .endif + bne 0b - cmp w7, #12 - b.ge 2f // AES-192/256? +3: ldp x19, x10, [sp, #24] + cbz x10, 5f // output tag? 
-1: enc_round KS0, v21 - ext IN1.16b, T1.16b, T1.16b, #8 + ld1 {INP3.16b}, [x10] // load lengths[] + mov w9, #1 + bl pmull_gcm_ghash_4x - enc_round KS1, v21 - pmull2 XH2.1q, SHASH.2d, IN1.2d // a1 * b1 + mov w11, #(0x1 << 24) // BE '1U' + ld1 {KS0.16b}, [x5] + mov KS0.s[3], w11 - enc_round KS0, v22 - eor T1.16b, T1.16b, IN1.16b + enc_block KS0, x7, x6, x12 - enc_round KS1, v22 - pmull XL2.1q, SHASH.1d, IN1.1d // a0 * b0 + ext XL.16b, XL.16b, XL.16b, #8 + rev64 XL.16b, XL.16b + eor XL.16b, XL.16b, KS0.16b + st1 {XL.16b}, [x10] // store tag - enc_round KS0, v23 - pmull XM2.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0) +4: ldp x29, x30, [sp], #32 + ret - enc_round KS1, v23 - rev64 T1.16b, INP0.16b - ext T2.16b, XL.16b, XL.16b, #8 +5: +CPU_LE( rev w8, w8 ) + str w8, [x5, #12] // store lower counter + st1 {XL.2d}, [x4] + b 4b + +6: ld1 {T1.16b-T2.16b}, [x17], #32 // permute vectors + sub x17, x17, x19, lsl #1 + + cmp w9, #1 + beq 7f + .subsection 1 +7: ld1 {INP2.16b}, [x1] + tbx INP2.16b, {INP3.16b}, T1.16b + mov INP3.16b, INP2.16b + b 8f + .previous + + st1 {INP0.16b}, [x1], x14 + st1 {INP1.16b}, [x1], x15 + st1 {INP2.16b}, [x1], x16 + tbl INP3.16b, {INP3.16b}, T1.16b + tbx INP3.16b, {INP2.16b}, T2.16b +8: st1 {INP3.16b}, [x1] - enc_round KS0, v24 - ext IN1.16b, T1.16b, T1.16b, #8 - eor T1.16b, T1.16b, T2.16b + .if \enc == 1 + ld1 {T1.16b}, [x17] + tbl INP3.16b, {INP3.16b}, T1.16b // clear non-data bits + bl pmull_gcm_ghash_4x + .endif + b 3b + .endm - enc_round KS1, v24 - eor XL.16b, XL.16b, IN1.16b + /* + * void pmull_gcm_encrypt(int blocks, u8 dst[], const u8 src[], + * struct ghash_key const *k, u64 dg[], u8 ctr[], + * int rounds, u8 tag) + */ +ENTRY(pmull_gcm_encrypt) + pmull_gcm_do_crypt 1 +ENDPROC(pmull_gcm_encrypt) - enc_round KS0, v25 - eor T1.16b, T1.16b, XL.16b + /* + * void pmull_gcm_decrypt(int blocks, u8 dst[], const u8 src[], + * struct ghash_key const *k, u64 dg[], u8 ctr[], + * int rounds, u8 tag) + */ +ENTRY(pmull_gcm_decrypt) + pmull_gcm_do_crypt 0 
+ENDPROC(pmull_gcm_decrypt) - enc_round KS1, v25 - pmull2 XH.1q, HH.2d, XL.2d // a1 * b1 +pmull_gcm_ghash_4x: + movi MASK.16b, #0xe1 + shl MASK.2d, MASK.2d, #57 - enc_round KS0, v26 - pmull XL.1q, HH.1d, XL.1d // a0 * b0 + rev64 T1.16b, INP0.16b + rev64 T2.16b, INP1.16b + rev64 TT3.16b, INP2.16b + rev64 TT4.16b, INP3.16b - enc_round KS1, v26 - pmull2 XM.1q, SHASH2.2d, T1.2d // (a1 + a0)(b1 + b0) + ext XL.16b, XL.16b, XL.16b, #8 - enc_round KS0, v27 - eor XL.16b, XL.16b, XL2.16b - eor XH.16b, XH.16b, XH2.16b + tbz w9, #2, 0f // <4 blocks? + .subsection 1 +0: movi XH2.16b, #0 + movi XM2.16b, #0 + movi XL2.16b, #0 - enc_round KS1, v27 - eor XM.16b, XM.16b, XM2.16b - ext T1.16b, XL.16b, XH.16b, #8 + tbz w9, #0, 1f // 2 blocks? + tbz w9, #1, 2f // 1 block? - enc_round KS0, v28 - eor T2.16b, XL.16b, XH.16b - eor XM.16b, XM.16b, T1.16b + eor T2.16b, T2.16b, XL.16b + ext T1.16b, T2.16b, T2.16b, #8 + b .Lgh3 - enc_round KS1, v28 - eor XM.16b, XM.16b, T2.16b +1: eor TT3.16b, TT3.16b, XL.16b + ext T2.16b, TT3.16b, TT3.16b, #8 + b .Lgh2 - enc_round KS0, v29 - pmull T2.1q, XL.1d, MASK.1d +2: eor TT4.16b, TT4.16b, XL.16b + ext IN1.16b, TT4.16b, TT4.16b, #8 + b .Lgh1 + .previous - enc_round KS1, v29 - mov XH.d[0], XM.d[1] - mov XM.d[1], XL.d[0] + eor T1.16b, T1.16b, XL.16b + ext IN1.16b, T1.16b, T1.16b, #8 - aese KS0.16b, v30.16b - eor XL.16b, XM.16b, T2.16b + pmull2 XH2.1q, HH4.2d, IN1.2d // a1 * b1 + eor T1.16b, T1.16b, IN1.16b + pmull XL2.1q, HH4.1d, IN1.1d // a0 * b0 + pmull2 XM2.1q, HH34.2d, T1.2d // (a1 + a0)(b1 + b0) - aese KS1.16b, v30.16b - ext T2.16b, XL.16b, XL.16b, #8 + ext T1.16b, T2.16b, T2.16b, #8 +.Lgh3: eor T2.16b, T2.16b, T1.16b + pmull2 XH.1q, HH3.2d, T1.2d // a1 * b1 + pmull XL.1q, HH3.1d, T1.1d // a0 * b0 + pmull XM.1q, HH34.1d, T2.1d // (a1 + a0)(b1 + b0) - eor KS0.16b, KS0.16b, v31.16b - pmull XL.1q, XL.1d, MASK.1d - eor T2.16b, T2.16b, XH.16b + eor XH2.16b, XH2.16b, XH.16b + eor XL2.16b, XL2.16b, XL.16b + eor XM2.16b, XM2.16b, XM.16b - eor KS1.16b, 
KS1.16b, v31.16b - eor XL.16b, XL.16b, T2.16b + ext T2.16b, TT3.16b, TT3.16b, #8 +.Lgh2: eor TT3.16b, TT3.16b, T2.16b + pmull2 XH.1q, HH.2d, T2.2d // a1 * b1 + pmull XL.1q, HH.1d, T2.1d // a0 * b0 + pmull2 XM.1q, SHASH2.2d, TT3.2d // (a1 + a0)(b1 + b0) - .if \enc == 0 - eor INP0.16b, INP0.16b, KS0.16b - eor INP1.16b, INP1.16b, KS1.16b - .endif + eor XH2.16b, XH2.16b, XH.16b + eor XL2.16b, XL2.16b, XL.16b + eor XM2.16b, XM2.16b, XM.16b - st1 {INP0.16b-INP1.16b}, [x2], #32 + ext IN1.16b, TT4.16b, TT4.16b, #8 +.Lgh1: eor TT4.16b, TT4.16b, IN1.16b + pmull XL.1q, SHASH.1d, IN1.1d // a0 * b0 + pmull2 XH.1q, SHASH.2d, IN1.2d // a1 * b1 + pmull XM.1q, SHASH2.1d, TT4.1d // (a1 + a0)(b1 + b0) - cbnz w0, 0b + eor XH.16b, XH.16b, XH2.16b + eor XL.16b, XL.16b, XL2.16b + eor XM.16b, XM.16b, XM2.16b -CPU_LE( rev x8, x8 ) - st1 {XL.2d}, [x1] - str x8, [x5, #8] // store lower counter + eor T2.16b, XL.16b, XH.16b + ext T1.16b, XL.16b, XH.16b, #8 + eor XM.16b, XM.16b, T2.16b - .if \enc == 1 - st1 {KS0.16b-KS1.16b}, [x10] - .endif + __pmull_reduce_p64 + + eor T2.16b, T2.16b, XH.16b + eor XL.16b, XL.16b, T2.16b ret +ENDPROC(pmull_gcm_ghash_4x) + +pmull_gcm_enc_4x: + ld1 {KS0.16b}, [x5] // load upper counter + sub w10, w8, #4 + sub w11, w8, #3 + sub w12, w8, #2 + sub w13, w8, #1 + rev w10, w10 + rev w11, w11 + rev w12, w12 + rev w13, w13 + mov KS1.16b, KS0.16b + mov KS2.16b, KS0.16b + mov KS3.16b, KS0.16b + ins KS0.s[3], w10 // set lower counter + ins KS1.s[3], w11 + ins KS2.s[3], w12 + ins KS3.s[3], w13 + + add x10, x6, #96 // round key pointer + ld1 {K6.4s-K7.4s}, [x10], #32 + .irp key, K0, K1, K2, K3, K4, K5 + enc_qround KS0, KS1, KS2, KS3, \key + .endr -2: b.eq 3f // AES-192? 
- enc_round KS0, v17 - enc_round KS1, v17 - enc_round KS0, v18 - enc_round KS1, v18 -3: enc_round KS0, v19 - enc_round KS1, v19 - enc_round KS0, v20 - enc_round KS1, v20 - b 1b + tbnz x7, #2, .Lnot128 + .subsection 1 +.Lnot128: + ld1 {K8.4s-K9.4s}, [x10], #32 + .irp key, K6, K7 + enc_qround KS0, KS1, KS2, KS3, \key + .endr + ld1 {K6.4s-K7.4s}, [x10] + .irp key, K8, K9 + enc_qround KS0, KS1, KS2, KS3, \key + .endr + tbz x7, #1, .Lout192 + b .Lout256 + .previous -4: load_round_keys w7, x6 - b 0b - .endm +.Lout256: + .irp key, K6, K7 + enc_qround KS0, KS1, KS2, KS3, \key + .endr - /* - * void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], const u8 src[], - * struct ghash_key const *k, u8 ctr[], - * int rounds, u8 ks[]) - */ -ENTRY(pmull_gcm_encrypt) - pmull_gcm_do_crypt 1 -ENDPROC(pmull_gcm_encrypt) +.Lout192: + enc_qround KS0, KS1, KS2, KS3, KK - /* - * void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], const u8 src[], - * struct ghash_key const *k, u8 ctr[], - * int rounds) - */ -ENTRY(pmull_gcm_decrypt) - pmull_gcm_do_crypt 0 -ENDPROC(pmull_gcm_decrypt) + aese KS0.16b, KL.16b + aese KS1.16b, KL.16b + aese KS2.16b, KL.16b + aese KS3.16b, KL.16b + + eor KS0.16b, KS0.16b, KM.16b + eor KS1.16b, KS1.16b, KM.16b + eor KS2.16b, KS2.16b, KM.16b + eor KS3.16b, KS3.16b, KM.16b + + eor INP0.16b, INP0.16b, KS0.16b + eor INP1.16b, INP1.16b, KS1.16b + eor INP2.16b, INP2.16b, KS2.16b + eor INP3.16b, INP3.16b, KS3.16b - /* - * void pmull_gcm_encrypt_block(u8 dst[], u8 src[], u8 rk[], int rounds) - */ -ENTRY(pmull_gcm_encrypt_block) - cbz x2, 0f - load_round_keys w3, x2 -0: ld1 {v0.16b}, [x1] - enc_block v0, w3 - st1 {v0.16b}, [x0] ret -ENDPROC(pmull_gcm_encrypt_block) +ENDPROC(pmull_gcm_enc_4x) + + .section ".rodata", "a" + .align 6 +.Lpermute_table: + .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + .byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7 + .byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + .byte 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff + .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + .byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7 + .byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + .previous diff --git a/arch/arm64/crypto/ghash-ce-glue.c b/arch/arm64/crypto/ghash-ce-glue.c index 70b1469783f9..522cf004ce65 100644 --- a/arch/arm64/crypto/ghash-ce-glue.c +++ b/arch/arm64/crypto/ghash-ce-glue.c @@ -58,17 +58,15 @@ asmlinkage void pmull_ghash_update_p8(int blocks, u64 dg[], const char *src, struct ghash_key const *k, const char *head); -asmlinkage void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], - const u8 src[], struct ghash_key const *k, +asmlinkage void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[], + struct ghash_key const *k, u64 dg[], u8 ctr[], u32 const rk[], int rounds, - u8 ks[]); + u8 tag[]); -asmlinkage void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], - const u8 src[], struct ghash_key const *k, - u8 ctr[], u32 const rk[], int rounds); - -asmlinkage void pmull_gcm_encrypt_block(u8 dst[], u8 const src[], - u32 const rk[], int rounds); +asmlinkage void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[], + struct ghash_key const *k, u64 dg[], + u8 ctr[], u32 const rk[], int rounds, + u8 tag[]); static int ghash_init(struct shash_desc *desc) { @@ -85,7 +83,7 @@ static void ghash_do_update(int blocks, u64 dg[], const char *src, struct ghash_key const *k, const char *head)) { - if (likely(crypto_simd_usable())) { + if (likely(crypto_simd_usable() && simd_update)) { kernel_neon_begin(); simd_update(blocks, dg, src, key, head); kernel_neon_end(); @@ -398,136 +396,112 @@ static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[]) } } -static void gcm_final(struct aead_request *req, struct gcm_aes_ctx *ctx, - u64 dg[], u8 tag[], int cryptlen) -{ - u8 mac[AES_BLOCK_SIZE]; - u128 lengths; - - lengths.a = cpu_to_be64(req->assoclen * 8); - lengths.b = cpu_to_be64(cryptlen * 8); - - ghash_do_update(1, dg, (void *)&lengths, 
&ctx->ghash_key, NULL, - pmull_ghash_update_p64); - - put_unaligned_be64(dg[1], mac); - put_unaligned_be64(dg[0], mac + 8); - - crypto_xor(tag, mac, AES_BLOCK_SIZE); -} - static int gcm_encrypt(struct aead_request *req) { struct crypto_aead *aead = crypto_aead_reqtfm(req); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); + int nrounds = num_rounds(&ctx->aes_key); struct skcipher_walk walk; + u8 buf[AES_BLOCK_SIZE]; u8 iv[AES_BLOCK_SIZE]; - u8 ks[2 * AES_BLOCK_SIZE]; - u8 tag[AES_BLOCK_SIZE]; u64 dg[2] = {}; - int nrounds = num_rounds(&ctx->aes_key); + u128 lengths; + u8 *tag; int err; + lengths.a = cpu_to_be64(req->assoclen * 8); + lengths.b = cpu_to_be64(req->cryptlen * 8); + if (req->assoclen) gcm_calculate_auth_mac(req, dg); memcpy(iv, req->iv, GCM_IV_SIZE); - put_unaligned_be32(1, iv + GCM_IV_SIZE); + put_unaligned_be32(2, iv + GCM_IV_SIZE); err = skcipher_walk_aead_encrypt(&walk, req, false); - if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) { - u32 const *rk = NULL; - - kernel_neon_begin(); - pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds); - put_unaligned_be32(2, iv + GCM_IV_SIZE); - pmull_gcm_encrypt_block(ks, iv, NULL, nrounds); - put_unaligned_be32(3, iv + GCM_IV_SIZE); - pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds); - put_unaligned_be32(4, iv + GCM_IV_SIZE); - + if (likely(crypto_simd_usable())) { do { - int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; + const u8 *src = walk.src.virt.addr; + u8 *dst = walk.dst.virt.addr; + int nbytes = walk.nbytes; + + tag = (u8 *)&lengths; - if (rk) - kernel_neon_begin(); + if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) { + src = dst = memcpy(buf + sizeof(buf) - nbytes, + src, nbytes); + } else if (nbytes < walk.total) { + nbytes &= ~(AES_BLOCK_SIZE - 1); + tag = NULL; + } - pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr, - walk.src.virt.addr, &ctx->ghash_key, - iv, rk, nrounds, ks); + kernel_neon_begin(); + pmull_gcm_encrypt(nbytes, dst, src, 
&ctx->ghash_key, dg, + iv, ctx->aes_key.key_enc, nrounds, + tag); kernel_neon_end(); - err = skcipher_walk_done(&walk, - walk.nbytes % (2 * AES_BLOCK_SIZE)); + if (unlikely(!nbytes)) + break; - rk = ctx->aes_key.key_enc; - } while (walk.nbytes >= 2 * AES_BLOCK_SIZE); - } else { - aes_encrypt(&ctx->aes_key, tag, iv); - put_unaligned_be32(2, iv + GCM_IV_SIZE); + if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) + memcpy(walk.dst.virt.addr, + buf + sizeof(buf) - nbytes, nbytes); - while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) { - const int blocks = - walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; + err = skcipher_walk_done(&walk, walk.nbytes - nbytes); + } while (walk.nbytes); + } else { + while (walk.nbytes >= AES_BLOCK_SIZE) { + int blocks = walk.nbytes / AES_BLOCK_SIZE; + const u8 *src = walk.src.virt.addr; u8 *dst = walk.dst.virt.addr; - u8 *src = walk.src.virt.addr; int remaining = blocks; do { - aes_encrypt(&ctx->aes_key, ks, iv); - crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE); + aes_encrypt(&ctx->aes_key, buf, iv); + crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE); crypto_inc(iv, AES_BLOCK_SIZE); dst += AES_BLOCK_SIZE; src += AES_BLOCK_SIZE; } while (--remaining > 0); - ghash_do_update(blocks, dg, - walk.dst.virt.addr, &ctx->ghash_key, - NULL, pmull_ghash_update_p64); + ghash_do_update(blocks, dg, walk.dst.virt.addr, + &ctx->ghash_key, NULL, NULL); err = skcipher_walk_done(&walk, - walk.nbytes % (2 * AES_BLOCK_SIZE)); - } - if (walk.nbytes) { - aes_encrypt(&ctx->aes_key, ks, iv); - if (walk.nbytes > AES_BLOCK_SIZE) { - crypto_inc(iv, AES_BLOCK_SIZE); - aes_encrypt(&ctx->aes_key, ks + AES_BLOCK_SIZE, iv); - } + walk.nbytes % AES_BLOCK_SIZE); } - } - /* handle the tail */ - if (walk.nbytes) { - u8 buf[GHASH_BLOCK_SIZE]; - unsigned int nbytes = walk.nbytes; - u8 *dst = walk.dst.virt.addr; - u8 *head = NULL; + /* handle the tail */ + if (walk.nbytes) { + aes_encrypt(&ctx->aes_key, buf, iv); - crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks, - walk.nbytes); + 
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, + buf, walk.nbytes); - if (walk.nbytes > GHASH_BLOCK_SIZE) { - head = dst; - dst += GHASH_BLOCK_SIZE; - nbytes %= GHASH_BLOCK_SIZE; + memcpy(buf, walk.dst.virt.addr, walk.nbytes); + memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes); } - memcpy(buf, dst, nbytes); - memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes); - ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head, - pmull_ghash_update_p64); + tag = (u8 *)&lengths; + ghash_do_update(1, dg, tag, &ctx->ghash_key, + walk.nbytes ? buf : NULL, NULL); - err = skcipher_walk_done(&walk, 0); + if (walk.nbytes) + err = skcipher_walk_done(&walk, 0); + + put_unaligned_be64(dg[1], tag); + put_unaligned_be64(dg[0], tag + 8); + put_unaligned_be32(1, iv + GCM_IV_SIZE); + aes_encrypt(&ctx->aes_key, iv, iv); + crypto_xor(tag, iv, AES_BLOCK_SIZE); } if (err) return err; - gcm_final(req, ctx, dg, tag, req->cryptlen); - /* copy authtag to end of dst */ scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen, crypto_aead_authsize(aead), 1); @@ -540,75 +514,65 @@ static int gcm_decrypt(struct aead_request *req) struct crypto_aead *aead = crypto_aead_reqtfm(req); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); unsigned int authsize = crypto_aead_authsize(aead); + int nrounds = num_rounds(&ctx->aes_key); struct skcipher_walk walk; - u8 iv[2 * AES_BLOCK_SIZE]; - u8 tag[AES_BLOCK_SIZE]; - u8 buf[2 * GHASH_BLOCK_SIZE]; + u8 buf[AES_BLOCK_SIZE]; + u8 iv[AES_BLOCK_SIZE]; u64 dg[2] = {}; - int nrounds = num_rounds(&ctx->aes_key); + u128 lengths; + u8 *tag; int err; + lengths.a = cpu_to_be64(req->assoclen * 8); + lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8); + if (req->assoclen) gcm_calculate_auth_mac(req, dg); memcpy(iv, req->iv, GCM_IV_SIZE); - put_unaligned_be32(1, iv + GCM_IV_SIZE); + put_unaligned_be32(2, iv + GCM_IV_SIZE); err = skcipher_walk_aead_decrypt(&walk, req, false); - if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) 
{ - u32 const *rk = NULL; - - kernel_neon_begin(); - pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds); - put_unaligned_be32(2, iv + GCM_IV_SIZE); - + if (likely(crypto_simd_usable())) { do { - int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; - int rem = walk.total - blocks * AES_BLOCK_SIZE; - - if (rk) - kernel_neon_begin(); - - pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr, - walk.src.virt.addr, &ctx->ghash_key, - iv, rk, nrounds); - - /* check if this is the final iteration of the loop */ - if (rem < (2 * AES_BLOCK_SIZE)) { - u8 *iv2 = iv + AES_BLOCK_SIZE; - - if (rem > AES_BLOCK_SIZE) { - memcpy(iv2, iv, AES_BLOCK_SIZE); - crypto_inc(iv2, AES_BLOCK_SIZE); - } + const u8 *src = walk.src.virt.addr; + u8 *dst = walk.dst.virt.addr; + int nbytes = walk.nbytes; - pmull_gcm_encrypt_block(iv, iv, NULL, nrounds); + tag = (u8 *)&lengths; - if (rem > AES_BLOCK_SIZE) - pmull_gcm_encrypt_block(iv2, iv2, NULL, - nrounds); + if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) { + src = dst = memcpy(buf + sizeof(buf) - nbytes, + src, nbytes); + } else if (nbytes < walk.total) { + nbytes &= ~(AES_BLOCK_SIZE - 1); + tag = NULL; } + kernel_neon_begin(); + pmull_gcm_decrypt(nbytes, dst, src, &ctx->ghash_key, dg, + iv, ctx->aes_key.key_enc, nrounds, + tag); kernel_neon_end(); - err = skcipher_walk_done(&walk, - walk.nbytes % (2 * AES_BLOCK_SIZE)); + if (unlikely(!nbytes)) + break; - rk = ctx->aes_key.key_enc; - } while (walk.nbytes >= 2 * AES_BLOCK_SIZE); - } else { - aes_encrypt(&ctx->aes_key, tag, iv); - put_unaligned_be32(2, iv + GCM_IV_SIZE); + if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) + memcpy(walk.dst.virt.addr, + buf + sizeof(buf) - nbytes, nbytes); - while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) { - int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; + err = skcipher_walk_done(&walk, walk.nbytes - nbytes); + } while (walk.nbytes); + } else { + while (walk.nbytes >= AES_BLOCK_SIZE) { + int blocks = walk.nbytes / AES_BLOCK_SIZE; + const u8 *src 
= walk.src.virt.addr; u8 *dst = walk.dst.virt.addr; - u8 *src = walk.src.virt.addr; ghash_do_update(blocks, dg, walk.src.virt.addr, - &ctx->ghash_key, NULL, - pmull_ghash_update_p64); + &ctx->ghash_key, NULL, NULL); do { aes_encrypt(&ctx->aes_key, buf, iv); @@ -620,49 +584,38 @@ static int gcm_decrypt(struct aead_request *req) } while (--blocks > 0); err = skcipher_walk_done(&walk, - walk.nbytes % (2 * AES_BLOCK_SIZE)); + walk.nbytes % AES_BLOCK_SIZE); } - if (walk.nbytes) { - if (walk.nbytes > AES_BLOCK_SIZE) { - u8 *iv2 = iv + AES_BLOCK_SIZE; - - memcpy(iv2, iv, AES_BLOCK_SIZE); - crypto_inc(iv2, AES_BLOCK_SIZE); - aes_encrypt(&ctx->aes_key, iv2, iv2); - } - aes_encrypt(&ctx->aes_key, iv, iv); + /* handle the tail */ + if (walk.nbytes) { + memcpy(buf, walk.src.virt.addr, walk.nbytes); + memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes); } - } - /* handle the tail */ - if (walk.nbytes) { - const u8 *src = walk.src.virt.addr; - const u8 *head = NULL; - unsigned int nbytes = walk.nbytes; + tag = (u8 *)&lengths; + ghash_do_update(1, dg, tag, &ctx->ghash_key, + walk.nbytes ? 
buf : NULL, NULL); - if (walk.nbytes > GHASH_BLOCK_SIZE) { - head = src; - src += GHASH_BLOCK_SIZE; - nbytes %= GHASH_BLOCK_SIZE; - } + if (walk.nbytes) { + aes_encrypt(&ctx->aes_key, buf, iv); - memcpy(buf, src, nbytes); - memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes); - ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head, - pmull_ghash_update_p64); + crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, + buf, walk.nbytes); - crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv, - walk.nbytes); + err = skcipher_walk_done(&walk, 0); + } - err = skcipher_walk_done(&walk, 0); + put_unaligned_be64(dg[1], tag); + put_unaligned_be64(dg[0], tag + 8); + put_unaligned_be32(1, iv + GCM_IV_SIZE); + aes_encrypt(&ctx->aes_key, iv, iv); + crypto_xor(tag, iv, AES_BLOCK_SIZE); } if (err) return err; - gcm_final(req, ctx, dg, tag, req->cryptlen - authsize); - /* compare calculated auth tag with the stored one */ scatterwalk_map_and_copy(buf, req->src, req->assoclen + req->cryptlen - authsize, @@ -675,7 +628,7 @@ static int gcm_decrypt(struct aead_request *req) static struct aead_alg gcm_aes_alg = { .ivsize = GCM_IV_SIZE, - .chunksize = 2 * AES_BLOCK_SIZE, + .chunksize = AES_BLOCK_SIZE, .maxauthsize = AES_BLOCK_SIZE, .setkey = gcm_setkey, .setauthsize = gcm_setauthsize, -- 2.17.1