diff --git a/arch/arm64/crypto/sm4-ce-gcm-glue.c b/arch/arm64/crypto/sm4-ce-gcm-glue.c
index c450a2025ca9..73bfb6972d3a 100644
--- a/arch/arm64/crypto/sm4-ce-gcm-glue.c
+++ b/arch/arm64/crypto/sm4-ce-gcm-glue.c
@@ -135,22 +135,23 @@ static void gcm_calculate_auth_mac(struct aead_request *req, u8 ghash[])
}
static int gcm_crypt(struct aead_request *req, struct skcipher_walk *walk,
- struct sm4_gcm_ctx *ctx, u8 ghash[],
+ u8 ghash[], int err,
void (*sm4_ce_pmull_gcm_crypt)(const u32 *rkey_enc,
u8 *dst, const u8 *src, u8 *iv,
unsigned int nbytes, u8 *ghash,
const u8 *ghash_table, const u8 *lengths))
{
+ struct crypto_aead *aead = crypto_aead_reqtfm(req);
+ struct sm4_gcm_ctx *ctx = crypto_aead_ctx(aead);
u8 __aligned(8) iv[SM4_BLOCK_SIZE];
be128 __aligned(8) lengths;
- int err;
memset(ghash, 0, SM4_BLOCK_SIZE);
lengths.a = cpu_to_be64(req->assoclen * 8);
lengths.b = cpu_to_be64(walk->total * 8);
- memcpy(iv, walk->iv, GCM_IV_SIZE);
+ memcpy(iv, req->iv, GCM_IV_SIZE);
put_unaligned_be32(2, iv + GCM_IV_SIZE);
kernel_neon_begin();
@@ -158,49 +159,51 @@ static int gcm_crypt(struct aead_request *req, struct skcipher_walk *walk,
if (req->assoclen)
gcm_calculate_auth_mac(req, ghash);
- do {
+ while (walk->nbytes) {
unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
const u8 *src = walk->src.virt.addr;
u8 *dst = walk->dst.virt.addr;
if (walk->nbytes == walk->total) {
- tail = 0;
-
sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, dst, src, iv,
walk->nbytes, ghash,
ctx->ghash_table,
(const u8 *)&lengths);
- } else if (walk->nbytes - tail) {
- sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, dst, src, iv,
- walk->nbytes - tail, ghash,
- ctx->ghash_table, NULL);
+
+ kernel_neon_end();
+
+ return skcipher_walk_done(walk, 0);
}
+ sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, dst, src, iv,
+ walk->nbytes - tail, ghash,
+ ctx->ghash_table, NULL);
+
kernel_neon_end();
err = skcipher_walk_done(walk, tail);
- if (err)
- return err;
- if (walk->nbytes)
- kernel_neon_begin();
- } while (walk->nbytes > 0);
- return 0;
+ kernel_neon_begin();
+ }
+
+ sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, NULL, NULL, iv,
+ walk->nbytes, ghash, ctx->ghash_table,
+ (const u8 *)&lengths);
+
+ kernel_neon_end();
+
+ return err;
}
static int gcm_encrypt(struct aead_request *req)
{
struct crypto_aead *aead = crypto_aead_reqtfm(req);
- struct sm4_gcm_ctx *ctx = crypto_aead_ctx(aead);
u8 __aligned(8) ghash[SM4_BLOCK_SIZE];
struct skcipher_walk walk;
int err;
err = skcipher_walk_aead_encrypt(&walk, req, false);
- if (err)
- return err;
-
- err = gcm_crypt(req, &walk, ctx, ghash, sm4_ce_pmull_gcm_enc);
+ err = gcm_crypt(req, &walk, ghash, err, sm4_ce_pmull_gcm_enc);
if (err)
return err;
@@ -215,17 +218,13 @@ static int gcm_decrypt(struct aead_request *req)
{
struct crypto_aead *aead = crypto_aead_reqtfm(req);
unsigned int authsize = crypto_aead_authsize(aead);
- struct sm4_gcm_ctx *ctx = crypto_aead_ctx(aead);
u8 __aligned(8) ghash[SM4_BLOCK_SIZE];
u8 authtag[SM4_BLOCK_SIZE];
struct skcipher_walk walk;
int err;
err = skcipher_walk_aead_decrypt(&walk, req, false);
- if (err)
- return err;
-
- err = gcm_crypt(req, &walk, ctx, ghash, sm4_ce_pmull_gcm_dec);
+ err = gcm_crypt(req, &walk, ghash, err, sm4_ce_pmull_gcm_dec);
if (err)
return err;