[PATCH V4 2/4] aes: add support of aes192 and aes256

Philippe Reynes philippe.reynes at softathome.com
Mon Jan 6 15:22:35 CET 2020


Until now, only aes128 was supported. This commit adds support
for aes192 and aes256.

Signed-off-by: Philippe Reynes <philippe.reynes at softathome.com>
---
 arch/arm/mach-tegra/tegra20/crypto.c | 41 ++++++++++---------
 cmd/aes.c                            | 38 ++++++++++++------
 include/uboot_aes.h                  | 34 ++++++++++------
 lib/aes.c                            | 77 +++++++++++++++++++++++++-----------
 4 files changed, 125 insertions(+), 65 deletions(-)

Changelog:
v4:
- update the tegra20 crypto driver to use the new aes api
  (squashes the previous patch 3, feedback from Simon)
v3:
- no change
v2:
- fix the help for the aes command


diff --git a/arch/arm/mach-tegra/tegra20/crypto.c b/arch/arm/mach-tegra/tegra20/crypto.c
index 66fbc3b..b91191e 100644
--- a/arch/arm/mach-tegra/tegra20/crypto.c
+++ b/arch/arm/mach-tegra/tegra20/crypto.c
@@ -39,34 +39,35 @@ static void left_shift_vector(u8 *in, u8 *out, int size)
 /**
  * Sign a block of data, putting the result into dst.
  *
- * \param key			Input AES key, length AES_KEY_LENGTH
+ * \param key			Input AES key, length AES128_KEY_LENGTH
  * \param key_schedule		Expanded key to use
  * \param src			Source data of length 'num_aes_blocks' blocks
- * \param dst			Destination buffer, length AES_KEY_LENGTH
+ * \param dst			Destination buffer, length AES128_KEY_LENGTH
  * \param num_aes_blocks	Number of AES blocks to encrypt
  */
 static void sign_object(u8 *key, u8 *key_schedule, u8 *src, u8 *dst,
 			u32 num_aes_blocks)
 {
-	u8 tmp_data[AES_KEY_LENGTH];
-	u8 iv[AES_KEY_LENGTH] = {0};
-	u8 left[AES_KEY_LENGTH];
-	u8 k1[AES_KEY_LENGTH];
+	u8 tmp_data[AES128_KEY_LENGTH];
+	u8 iv[AES128_KEY_LENGTH] = {0};
+	u8 left[AES128_KEY_LENGTH];
+	u8 k1[AES128_KEY_LENGTH];
 	u8 *cbc_chain_data;
 	unsigned i;
 
 	cbc_chain_data = zero_key;	/* Convenient array of 0's for IV */
 
 	/* compute K1 constant needed by AES-CMAC calculation */
-	for (i = 0; i < AES_KEY_LENGTH; i++)
+	for (i = 0; i < AES128_KEY_LENGTH; i++)
 		tmp_data[i] = 0;
 
-	aes_cbc_encrypt_blocks(key_schedule, iv, tmp_data, left, 1);
+	aes_cbc_encrypt_blocks(AES128_KEY_LENGTH, key_schedule, iv,
+			       tmp_data, left, 1);
 
 	left_shift_vector(left, k1, sizeof(left));
 
 	if ((left[0] >> 7) != 0) /* get MSB of L */
-		k1[AES_KEY_LENGTH-1] ^= AES_CMAC_CONST_RB;
+		k1[AES128_KEY_LENGTH - 1] ^= AES_CMAC_CONST_RB;
 
 	/* compute the AES-CMAC value */
 	for (i = 0; i < num_aes_blocks; i++) {
@@ -78,31 +79,32 @@ static void sign_object(u8 *key, u8 *key_schedule, u8 *src, u8 *dst,
 			aes_apply_cbc_chain_data(tmp_data, k1, tmp_data);
 
 		/* encrypt the AES block */
-		aes_encrypt(tmp_data, key_schedule, dst);
+		aes_encrypt(AES128_KEY_LENGTH, tmp_data,
+			    key_schedule, dst);
 
 		debug("sign_obj: block %d of %d\n", i, num_aes_blocks);
 
 		/* Update pointers for next loop. */
 		cbc_chain_data = dst;
-		src += AES_KEY_LENGTH;
+		src += AES128_KEY_LENGTH;
 	}
 }
 
 /**
  * Encrypt and sign a block of data (depending on security mode).
  *
- * \param key		Input AES key, length AES_KEY_LENGTH
+ * \param key		Input AES key, length AES128_KEY_LENGTH
  * \param oper		Security operations mask to perform (enum security_op)
  * \param src		Source data
  * \param length	Size of source data
- * \param sig_dst	Destination address for signature, AES_KEY_LENGTH bytes
+ * \param sig_dst	Destination address for signature, AES128_KEY_LENGTH bytes
  */
 static int encrypt_and_sign(u8 *key, enum security_op oper, u8 *src,
 			    u32 length, u8 *sig_dst)
 {
 	u32 num_aes_blocks;
-	u8 key_schedule[AES_EXPAND_KEY_LENGTH];
-	u8 iv[AES_KEY_LENGTH] = {0};
+	u8 key_schedule[AES128_EXPAND_KEY_LENGTH];
+	u8 iv[AES128_KEY_LENGTH] = {0};
 
 	debug("encrypt_and_sign: length = %d\n", length);
 
@@ -110,15 +112,16 @@ static int encrypt_and_sign(u8 *key, enum security_op oper, u8 *src,
 	 * The only need for a key is for signing/checksum purposes, so
 	 * if not encrypting, expand a key of 0s.
 	 */
-	aes_expand_key(oper & SECURITY_ENCRYPT ? key : zero_key, key_schedule);
+	aes_expand_key(oper & SECURITY_ENCRYPT ? key : zero_key,
+		       AES128_KEY_LENGTH, key_schedule);
 
-	num_aes_blocks = (length + AES_KEY_LENGTH - 1) / AES_KEY_LENGTH;
+	num_aes_blocks = (length + AES128_KEY_LENGTH - 1) / AES128_KEY_LENGTH;
 
 	if (oper & SECURITY_ENCRYPT) {
 		/* Perform this in place, resulting in src being encrypted. */
 		debug("encrypt_and_sign: begin encryption\n");
-		aes_cbc_encrypt_blocks(key_schedule, iv, src, src,
-				       num_aes_blocks);
+		aes_cbc_encrypt_blocks(AES128_KEY_LENGTH, key_schedule, iv, src,
+				       src, num_aes_blocks);
 		debug("encrypt_and_sign: end encryption\n");
 	}
 
diff --git a/cmd/aes.c b/cmd/aes.c
index 24b0256..8c5b42f 100644
--- a/cmd/aes.c
+++ b/cmd/aes.c
@@ -2,7 +2,7 @@
 /*
  * Copyright (C) 2014 Marek Vasut <marex at denx.de>
  *
- * Command for en/de-crypting block of memory with AES-128-CBC cipher.
+ * Command for en/de-crypting block of memory with AES-[128/192/256]-CBC cipher.
  */
 
 #include <common.h>
@@ -13,6 +13,18 @@
 #include <linux/compiler.h>
 #include <mapmem.h>
 
+/* Key length in bytes selected by the command name; defaults to AES-128. */
+u32 aes_get_key_len(char *command)
+{
+	if (!strcmp(command, "aes.256"))
+		return AES256_KEY_LENGTH;
+
+	if (!strcmp(command, "aes.192"))
+		return AES192_KEY_LENGTH;
+
+	return AES128_KEY_LENGTH;
+}
+
 /**
  * do_aes() - Handle the "aes" command-line command
  * @cmdtp:	Command data struct pointer
@@ -27,13 +39,15 @@ static int do_aes(cmd_tbl_t *cmdtp, int flag, int argc, char *const argv[])
 {
 	uint32_t key_addr, iv_addr, src_addr, dst_addr, len;
 	uint8_t *key_ptr, *iv_ptr, *src_ptr, *dst_ptr;
-	uint8_t key_exp[AES_EXPAND_KEY_LENGTH];
-	uint32_t aes_blocks;
+	u8 key_exp[AES256_EXPAND_KEY_LENGTH];
+	u32 aes_blocks, key_len;
 	int enc;
 
 	if (argc != 7)
 		return CMD_RET_USAGE;
 
+	key_len = aes_get_key_len(argv[0]);
+
 	if (!strncmp(argv[1], "enc", 3))
 		enc = 1;
 	else if (!strncmp(argv[1], "dec", 3))
@@ -47,23 +61,23 @@ static int do_aes(cmd_tbl_t *cmdtp, int flag, int argc, char *const argv[])
 	dst_addr = simple_strtoul(argv[5], NULL, 16);
 	len = simple_strtoul(argv[6], NULL, 16);
 
-	key_ptr = (uint8_t *)map_sysmem(key_addr, 128 / 8);
+	key_ptr = (uint8_t *)map_sysmem(key_addr, key_len);
 	iv_ptr = (uint8_t *)map_sysmem(iv_addr, 128 / 8);
 	src_ptr = (uint8_t *)map_sysmem(src_addr, len);
 	dst_ptr = (uint8_t *)map_sysmem(dst_addr, len);
 
 	/* First we expand the key. */
-	aes_expand_key(key_ptr, key_exp);
+	aes_expand_key(key_ptr, key_len, key_exp);
 
 	/* Calculate the number of AES blocks to encrypt. */
 	aes_blocks = DIV_ROUND_UP(len, AES_BLOCK_LENGTH);
 
 	if (enc)
-		aes_cbc_encrypt_blocks(key_exp, iv_ptr, src_ptr, dst_ptr,
-				       aes_blocks);
+		aes_cbc_encrypt_blocks(key_len, key_exp, iv_ptr, src_ptr,
+				       dst_ptr, aes_blocks);
 	else
-		aes_cbc_decrypt_blocks(key_exp, iv_ptr, src_ptr, dst_ptr,
-				       aes_blocks);
+		aes_cbc_decrypt_blocks(key_len, key_exp, iv_ptr, src_ptr,
+				       dst_ptr, aes_blocks);
 
 	unmap_sysmem(key_ptr);
 	unmap_sysmem(iv_ptr);
@@ -76,13 +90,13 @@ static int do_aes(cmd_tbl_t *cmdtp, int flag, int argc, char *const argv[])
 /***************************************************/
 #ifdef CONFIG_SYS_LONGHELP
 static char aes_help_text[] =
-	"enc key iv src dst len - Encrypt block of data $len bytes long\n"
+	"[.128,.192,.256] enc key iv src dst len - Encrypt block of data $len bytes long\n"
 	"                             at address $src using a key at address\n"
 	"                             $key with initialization vector at address\n"
 	"                             $iv. Store the result at address $dst.\n"
 	"                             The $len size must be multiple of 16 bytes.\n"
 	"                             The $key and $iv must be 16 bytes long.\n"
-	"aes dec key iv src dst len - Decrypt block of data $len bytes long\n"
+	"aes [.128,.192,.256] dec key iv src dst len - Decrypt block of data $len bytes long\n"
 	"                             at address $src using a key at address\n"
 	"                             $key with initialization vector at address\n"
 	"                             $iv. Store the result at address $dst.\n"
@@ -92,6 +106,6 @@ static char aes_help_text[] =
 
 U_BOOT_CMD(
 	aes, 7, 1, do_aes,
-	"AES 128 CBC encryption",
+	"AES 128/192/256 CBC encryption",
 	aes_help_text
 );
diff --git a/include/uboot_aes.h b/include/uboot_aes.h
index 1ae3ac9..d2583be 100644
--- a/include/uboot_aes.h
+++ b/include/uboot_aes.h
@@ -23,11 +23,18 @@ typedef unsigned int u32;
 
 enum {
 	AES_STATECOLS	= 4,	/* columns in the state & expanded key */
-	AES_KEYCOLS	= 4,	/* columns in a key */
-	AES_ROUNDS	= 10,	/* rounds in encryption */
-
-	AES_KEY_LENGTH	= 128 / 8,
-	AES_EXPAND_KEY_LENGTH	= 4 * AES_STATECOLS * (AES_ROUNDS + 1),
+	AES128_KEYCOLS	= 4,	/* columns in a key for aes128 */
+	AES192_KEYCOLS	= 6,	/* columns in a key for aes192 */
+	AES256_KEYCOLS	= 8,	/* columns in a key for aes256 */
+	AES128_ROUNDS	= 10,	/* rounds in encryption for aes128 */
+	AES192_ROUNDS	= 12,	/* rounds in encryption for aes192 */
+	AES256_ROUNDS	= 14,	/* rounds in encryption for aes256 */
+	AES128_KEY_LENGTH	= 128 / 8,
+	AES192_KEY_LENGTH	= 192 / 8,
+	AES256_KEY_LENGTH	= 256 / 8,
+	AES128_EXPAND_KEY_LENGTH = 4 * AES_STATECOLS * (AES128_ROUNDS + 1),
+	AES192_EXPAND_KEY_LENGTH = 4 * AES_STATECOLS * (AES192_ROUNDS + 1),
+	AES256_EXPAND_KEY_LENGTH = 4 * AES_STATECOLS * (AES256_ROUNDS + 1),
 	AES_BLOCK_LENGTH	= 128 / 8,
 };
 
@@ -37,28 +44,31 @@ enum {
  * Expand a key into a key schedule, which is then used for the other
  * operations.
  *
- * @key		Key, of length AES_KEY_LENGTH bytes
+ * @key		Key
+ * @key_size	Size of the key (in bytes)
  * @expkey	Buffer to place expanded key, AES_EXPAND_KEY_LENGTH
  */
-void aes_expand_key(u8 *key, u8 *expkey);
+void aes_expand_key(u8 *key, u32 key_size, u8 *expkey);
 
 /**
  * aes_encrypt() - Encrypt single block of data with AES 128
  *
+ * @key_size	Size of the aes key (in bytes)
  * @in		Input data
  * @expkey	Expanded key to use for encryption (from aes_expand_key())
  * @out		Output data
  */
-void aes_encrypt(u8 *in, u8 *expkey, u8 *out);
+void aes_encrypt(u32 key_size, u8 *in, u8 *expkey, u8 *out);
 
 /**
  * aes_decrypt() - Decrypt single block of data with AES 128
  *
+ * @key_size	Size of the aes key (in bytes)
  * @in		Input data
  * @expkey	Expanded key to use for decryption (from aes_expand_key())
  * @out		Output data
  */
-void aes_decrypt(u8 *in, u8 *expkey, u8 *out);
+void aes_decrypt(u32 key_size, u8 *in, u8 *expkey, u8 *out);
 
 /**
  * Apply chain data to the destination using EOR
@@ -74,25 +84,27 @@ void aes_apply_cbc_chain_data(u8 *cbc_chain_data, u8 *src, u8 *dst);
 /**
  * aes_cbc_encrypt_blocks() - Encrypt multiple blocks of data with AES CBC.
  *
+ * @key_size		Size of the aes key (in bytes)
  * @key_exp		Expanded key to use
  * @iv			Initialization vector
  * @src			Source data to encrypt
  * @dst			Destination buffer
  * @num_aes_blocks	Number of AES blocks to encrypt
  */
-void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+void aes_cbc_encrypt_blocks(u32 key_size, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 			    u32 num_aes_blocks);
 
 /**
  * Decrypt multiple blocks of data with AES CBC.
  *
+ * @key_size		Size of the aes key (in bytes)
  * @key_exp		Expanded key to use
  * @iv			Initialization vector
  * @src			Source data to decrypt
  * @dst			Destination buffer
  * @num_aes_blocks	Number of AES blocks to decrypt
  */
-void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+void aes_cbc_decrypt_blocks(u32 key_size, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 			    u32 num_aes_blocks);
 
 #endif /* _AES_REF_H_ */
diff --git a/lib/aes.c b/lib/aes.c
index cfa57b6..ce53c9f 100644
--- a/lib/aes.c
+++ b/lib/aes.c
@@ -508,50 +508,79 @@ static u8 rcon[11] = {
 	0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36
 };
 
+/* Rounds for a key length given in bytes; defaults to AES128_ROUNDS. */
+static u32 aes_get_rounds(u32 key_len)
+{
+	if (key_len == AES256_KEY_LENGTH)
+		return AES256_ROUNDS;
+
+	if (key_len == AES192_KEY_LENGTH)
+		return AES192_ROUNDS;
+
+	return AES128_ROUNDS;
+}
+
+/* Key columns for a key length given in bytes; defaults to AES128_KEYCOLS. */
+static u32 aes_get_keycols(u32 key_len)
+{
+	if (key_len == AES256_KEY_LENGTH)
+		return AES256_KEYCOLS;
+
+	if (key_len == AES192_KEY_LENGTH)
+		return AES192_KEYCOLS;
+
+	return AES128_KEYCOLS;
+}
+
 /* produce AES_STATECOLS bytes for each round */
-void aes_expand_key(u8 *key, u8 *expkey)
+void aes_expand_key(u8 *key, u32 key_len, u8 *expkey)
 {
 	u8 tmp0, tmp1, tmp2, tmp3, tmp4;
-	u32 idx;
+	u32 idx, aes_rounds, aes_keycols;
 
-	memcpy(expkey, key, AES_KEYCOLS * 4);
+	aes_rounds = aes_get_rounds(key_len);
+	aes_keycols = aes_get_keycols(key_len);
 
-	for (idx = AES_KEYCOLS; idx < AES_STATECOLS * (AES_ROUNDS + 1); idx++) {
+	memcpy(expkey, key, key_len);
+
+	for (idx = aes_keycols; idx < AES_STATECOLS * (aes_rounds + 1); idx++) {
 		tmp0 = expkey[4*idx - 4];
 		tmp1 = expkey[4*idx - 3];
 		tmp2 = expkey[4*idx - 2];
 		tmp3 = expkey[4*idx - 1];
-		if (!(idx % AES_KEYCOLS)) {
+		if (!(idx % aes_keycols)) {
 			tmp4 = tmp3;
 			tmp3 = sbox[tmp0];
-			tmp0 = sbox[tmp1] ^ rcon[idx / AES_KEYCOLS];
+			tmp0 = sbox[tmp1] ^ rcon[idx / aes_keycols];
 			tmp1 = sbox[tmp2];
 			tmp2 = sbox[tmp4];
-		} else if ((AES_KEYCOLS > 6) && (idx % AES_KEYCOLS == 4)) {
+		} else if ((aes_keycols > 6) && (idx % aes_keycols == 4)) {
 			tmp0 = sbox[tmp0];
 			tmp1 = sbox[tmp1];
 			tmp2 = sbox[tmp2];
 			tmp3 = sbox[tmp3];
 		}
 
-		expkey[4*idx+0] = expkey[4*idx - 4*AES_KEYCOLS + 0] ^ tmp0;
-		expkey[4*idx+1] = expkey[4*idx - 4*AES_KEYCOLS + 1] ^ tmp1;
-		expkey[4*idx+2] = expkey[4*idx - 4*AES_KEYCOLS + 2] ^ tmp2;
-		expkey[4*idx+3] = expkey[4*idx - 4*AES_KEYCOLS + 3] ^ tmp3;
+		expkey[4*idx+0] = expkey[4*idx - 4*aes_keycols + 0] ^ tmp0;
+		expkey[4*idx+1] = expkey[4*idx - 4*aes_keycols + 1] ^ tmp1;
+		expkey[4*idx+2] = expkey[4*idx - 4*aes_keycols + 2] ^ tmp2;
+		expkey[4*idx+3] = expkey[4*idx - 4*aes_keycols + 3] ^ tmp3;
 	}
 }
 
 /* encrypt one 128 bit block */
-void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
+void aes_encrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
 {
 	u8 state[AES_STATECOLS * 4];
-	u32 round;
+	u32 round, aes_rounds;
+
+	aes_rounds = aes_get_rounds(key_len);
 
 	memcpy(state, in, AES_STATECOLS * 4);
 	add_round_key((u32 *)state, (u32 *)expkey);
 
-	for (round = 1; round < AES_ROUNDS + 1; round++) {
-		if (round < AES_ROUNDS)
+	for (round = 1; round < aes_rounds + 1; round++) {
+		if (round < aes_rounds)
 			mix_sub_columns(state);
 		else
 			shift_rows(state);
@@ -563,18 +592,20 @@ void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
 	memcpy(out, state, sizeof(state));
 }
 
-void aes_decrypt(u8 *in, u8 *expkey, u8 *out)
+void aes_decrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
 {
 	u8 state[AES_STATECOLS * 4];
-	int round;
+	int round, aes_rounds;
+
+	aes_rounds = aes_get_rounds(key_len);
 
 	memcpy(state, in, sizeof(state));
 
 	add_round_key((u32 *)state,
-		      (u32 *)expkey + AES_ROUNDS * AES_STATECOLS);
+		      (u32 *)expkey + aes_rounds * AES_STATECOLS);
 	inv_shift_rows(state);
 
-	for (round = AES_ROUNDS; round--; ) {
+	for (round = aes_rounds; round--; ) {
 		add_round_key((u32 *)state,
 			      (u32 *)expkey + round * AES_STATECOLS);
 		if (round)
@@ -600,7 +631,7 @@ void aes_apply_cbc_chain_data(u8 *cbc_chain_data, u8 *src, u8 *dst)
 		*dst++ = *src++ ^ *cbc_chain_data++;
 }
 
-void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+void aes_cbc_encrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 			    u32 num_aes_blocks)
 {
 	u8 tmp_data[AES_BLOCK_LENGTH];
@@ -616,7 +647,7 @@ void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 		debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
 
 		/* Encrypt the AES block */
-		aes_encrypt(tmp_data, key_exp, dst);
+		aes_encrypt(key_len, tmp_data, key_exp, dst);
 		debug_print_vector("AES Dst", AES_BLOCK_LENGTH, dst);
 
 		/* Update pointers for next loop. */
@@ -626,7 +657,7 @@ void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 	}
 }
 
-void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+void aes_cbc_decrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 			    u32 num_aes_blocks)
 {
 	u8 tmp_data[AES_BLOCK_LENGTH], tmp_block[AES_BLOCK_LENGTH];
@@ -642,7 +673,7 @@ void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
 		memcpy(tmp_block, src, AES_BLOCK_LENGTH);
 
 		/* Decrypt the AES block */
-		aes_decrypt(src, key_exp, tmp_data);
+		aes_decrypt(key_len, src, key_exp, tmp_data);
 		debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
 
 		/* Apply the chain data */
-- 
2.7.4



More information about the U-Boot mailing list