This is a quick write-up about a simple Python implementation of AES-128 in ECB mode. As usual, the goal is to understand the spec, not the math behind it (I am too old for this) nor an efficient implementation.
The NIST specification (FIPS 197) defines the AES algorithm as follows. The algorithm uses the ADDROUNDKEY, SUBBYTES, SHIFTROWS, and MIXCOLUMNS operations to mix the key with the message block.
> Algorithm 1 Pseudocode for CIPHER()
> 1: procedure CIPHER(in, Nr, w)
> 2: state ← in ▷ See Sec. 3.4
> 3: state ← ADDROUNDKEY(state, w[0..3]) ▷ See Sec. 5.1.4
> 4: for round from 1 to Nr − 1 do
> 5: state ← SUBBYTES(state) ▷ See Sec. 5.1.1
> 6: state ← SHIFTROWS(state) ▷ See Sec. 5.1.2
> 7: state ← MIXCOLUMNS(state) ▷ See Sec. 5.1.3
> 8: state ← ADDROUNDKEY(state, w[4 ∗ round..4 ∗ round + 3])
> 9: end for
> 10: state ← SUBBYTES(state)
> 11: state ← SHIFTROWS(state)
> 12: state ← ADDROUNDKEY(state, w[4 ∗ Nr..4 ∗ Nr + 3])
> 13: return state ▷ See Sec. 3.4
> 14: end procedure
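To see how the pseudocode consumes the key schedule, here is a small sketch that just lists the operations CIPHER() performs for a given Nr, together with the range of schedule words each ADDROUNDKEY uses. The function name `cipher_steps` is only illustrative; it is not part of the spec.

```python
def cipher_steps(nr=10):
    """Return the sequence of operations CIPHER() applies for `nr` rounds.

    ADDROUNDKEY entries carry (first_word, last_word) of the key schedule;
    the byte-level transformations carry None.
    """
    steps = [("ADDROUNDKEY", (0, 3))]  # initial round key, before any round
    for rnd in range(1, nr):  # the Nr - 1 full rounds
        steps += [
            ("SUBBYTES", None),
            ("SHIFTROWS", None),
            ("MIXCOLUMNS", None),
            ("ADDROUNDKEY", (4 * rnd, 4 * rnd + 3)),
        ]
    # the final round skips MIXCOLUMNS
    steps += [
        ("SUBBYTES", None),
        ("SHIFTROWS", None),
        ("ADDROUNDKEY", (4 * nr, 4 * nr + 3)),
    ]
    return steps


steps = cipher_steps(10)
print(sum(op == "ADDROUNDKEY" for op, _ in steps))  # 11
print(steps[-1])  # ('ADDROUNDKEY', (40, 43))
```

So AES-128 needs 11 round keys, i.e. 4 ∗ (Nr + 1) = 44 words, which is exactly what KEYEXPANSION has to produce.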
## Key expansion
Basically, KEYEXPANSION expands the 16-byte AES-128 key into 4 ∗ (Nr + 1) = 44 four-byte words (176 bytes), i.e. one 16-byte round key for each of the 11 ADDROUNDKEY calls. It also uses the S-box (through SUBWORD) and Rcon, the round constants.
> KEYEXPANSION() is a routine that is applied to the key to generate 4 ∗ (Nr + 1) words. Thus, four words are generated for each of the Nr + 1 applications of ADDROUNDKEY() within the specification of CIPHER(), as described in Section 5.1.4. The output of the routine consists of a linear array of words, denoted by w[i], where i is in the range 0 ≤ i < 4 ∗ (Nr + 1). KEYEXPANSION() invokes 10 fixed words denoted by Rcon[j] for 1 ≤ j ≤ 10. These 10 words are called the round constants. For AES-128, a distinct round constant is called in the generation of each of the 10 round keys. For AES-192 and AES-256, the key expansion routine calls the first eight and seven of these same constants, respectively.
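The numbers in that paragraph can be reproduced in a few lines. This sketch assumes Nr = Nk + 6 (10, 12 and 14 rounds for AES-128/192/256, as in the spec's key-block-round table) and derives Rcon by repeated multiplication by {02} in GF(2^8); `round_constants` and `schedule_sizes` are illustrative names, not from the spec.

```python
def round_constants(count=10):
    """Rcon[1..count]: successive powers of {02} in GF(2^8)."""
    rc, out = 0x01, []
    for _ in range(count):
        out.append(rc)
        # multiply by {02}; on overflow reduce by x^8 + x^4 + x^3 + x + 1 ({11b})
        rc = ((rc << 1) ^ 0x11B) if rc & 0x80 else (rc << 1)
    return out


def schedule_sizes(nk):
    """Return (Nr, number of schedule words, number of Rcon values used) for a key of nk words."""
    nr = nk + 6
    total_words = 4 * (nr + 1)
    # an Rcon value is consumed for every i in [Nk, total_words) with i % Nk == 0
    rcon_used = sum(1 for i in range(nk, total_words) if i % nk == 0)
    return nr, total_words, rcon_used


print([hex(x) for x in round_constants()])
for name, nk in (("AES-128", 4), ("AES-192", 6), ("AES-256", 8)):
    print(name, schedule_sizes(nk))
```

The generated list is the same as the hard-coded `Rcon` in `key_expansion`, and the Rcon counts come out as 10, 8 and 7, matching the spec text.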
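```python
def key_expansion(key, Nk, Nb):
    # uses the S_BOX table from the S-box section below
    Nr = Nk + 6  # 10, 12 or 14 rounds for AES-128/192/256

    def sub_word(word):
        # SUBWORD: apply the S-box to each of the 4 bytes
        return bytes(S_BOX[b] for b in word)

    def rot_word(word):
        # ROTWORD: [a0, a1, a2, a3] -> [a1, a2, a3, a0]
        return word[1:] + word[:1]

    Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
    # the first Nk words are the key itself
    w = [key[4 * i:4 * i + 4] for i in range(Nk)]
    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i - 1]
        if i % Nk == 0:
            temp = bytes(a ^ b for a, b in zip(sub_word(rot_word(temp)), [Rcon[i // Nk - 1], 0, 0, 0]))
        elif Nk > 6 and i % Nk == 4:
            # extra SUBWORD step, only used by AES-256
            temp = sub_word(temp)
        w.append(bytes(a ^ b for a, b in zip(w[i - Nk], temp)))
    return w
```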
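A quick way to check the key schedule is the worked example in FIPS 197 Appendix A.1, which expands the key 2b7e1516 28aed2a6 abf71588 09cf4f3c: there w[4] is a0fafe17 and the last four words are d014f9a8 c9ee2589 e13f0cc8 b6630ca6. To keep the sketch self-contained it rebuilds the S-box from its definition (GF(2^8) inverse followed by the affine map) instead of pasting the table; `gf_mul` and `build_sbox` are illustrative helpers, not part of the spec.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial {11b}."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = ((a << 1) ^ 0x11B) if a & 0x80 else (a << 1)
        b >>= 1
    return p


def build_sbox():
    """Rebuild the AES S-box: inverse in GF(2^8) ({00} -> {00}), then the affine map with c = {63}."""
    def rotl(v, k):
        return ((v << k) | (v >> (8 - k))) & 0xFF

    table = []
    for x in range(256):
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        table.append(inv ^ rotl(inv, 1) ^ rotl(inv, 2) ^ rotl(inv, 3) ^ rotl(inv, 4) ^ 0x63)
    return tuple(table)


S_BOX = build_sbox()


def key_expansion(key, Nk, Nb):
    """The key schedule from above, with Nr derived from Nk."""
    Nr = Nk + 6

    def sub_word(word):
        return bytes(S_BOX[b] for b in word)

    def rot_word(word):
        return word[1:] + word[:1]

    Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
    w = [key[4 * i:4 * i + 4] for i in range(Nk)]
    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i - 1]
        if i % Nk == 0:
            temp = bytes(a ^ b for a, b in zip(sub_word(rot_word(temp)), [Rcon[i // Nk - 1], 0, 0, 0]))
        elif Nk > 6 and i % Nk == 4:
            temp = sub_word(temp)
        w.append(bytes(a ^ b for a, b in zip(w[i - Nk], temp)))
    return w


key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")
w = key_expansion(key, 4, 4)
print(len(w), w[4].hex(), w[43].hex())  # 44 a0fafe17 b6630ca6
```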
## Add round key
ADDROUNDKEY XORs each column of the state with one word of the round key to produce the new state. From the spec:
> ADDROUNDKEY() is a transformation of the state in which a round key is combined with the state by applying the bitwise XOR operation. In particular, each round key consists of four words from the key schedule (described in Section 5.2), each of which is combined with a column of the state as follows:
> $$[s'_{0,c},\ s'_{1,c},\ s'_{2,c},\ s'_{3,c}] = [s_{0,c},\ s_{1,c},\ s_{2,c},\ s_{3,c}] \oplus [w_{4 \cdot round + c}] \quad \text{for } 0 \le c < 4$$
```python
def add_round_key(state, round_key):
    new_state = bytearray(state)
    for i in range(4):          # column i of the state ...
        for j in range(4):      # ... row j
            # XOR with byte j of word w[4 * round + i]
            new_state[i * 4 + j] ^= round_key[i][j]
    return bytes(new_state)
```
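A small check of the column layout, using the FIPS 197 Appendix B vector. The state is stored column by column (byte index 4 * c + r), so word c of the round key lines up with column c. For the initial ADDROUNDKEY the round key is w[0..3], which is just the cipher key split into words. The loop variables are renamed to c and r here only for readability.

```python
def add_round_key(state, round_key):
    """XOR the 16-byte column-major state with four 4-byte words, one per column."""
    new_state = bytearray(state)
    for c in range(4):
        for r in range(4):
            new_state[4 * c + r] ^= round_key[c][r]
    return bytes(new_state)


plaintext = bytes.fromhex("3243f6a8885a308d313198a2e0370734")
key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")
round_key_0 = [key[4 * c:4 * c + 4] for c in range(4)]  # w[0..3] are the key itself

state = add_round_key(plaintext, round_key_0)
print(state.hex())  # 193de3bea0f4e22b9ac68d2ae9f84808
# XOR is its own inverse: applying the same round key again restores the input
print(add_round_key(state, round_key_0) == plaintext)  # True
```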
## S-box
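From the spec, SUBBYTES (the S-box substitution) is defined as follows:
> SUBBYTES() is an invertible, non-linear transformation of the state in which a substitution table, called an S-box, is applied independently to each byte in the state. The AES S-box is denoted by SBOX(). Let b denote an input byte to SBOX(), and let c denote the constant byte {01100011}. The output byte b′ = SBOX(b) is constructed by composing the following two transformations:

The two transformations are: take the multiplicative inverse of b in GF(2^8) ({00} maps to itself), then apply the affine transformation $b'_i = b_i \oplus b_{(i+4) \bmod 8} \oplus b_{(i+5) \bmod 8} \oplus b_{(i+6) \bmod 8} \oplus b_{(i+7) \bmod 8} \oplus c_i$. In practice the implementation just uses the precomputed table, stored as one flat array of 256 entries: the high nibble of the input byte selects the row and the low nibble the column of the 16 × 16 S-box matrix (I have another post about generating that matrix according to the spec). For example, 0x01 is row 0, column 1, which holds 0x7c.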
```python
S_BOX = (
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
)


def sub_bytes(state):
    # S_BOX[b] is row b >> 4, column b & 0x0f of the 16 x 16 matrix
    return bytes(S_BOX[b] for b in state)
```
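As a sketch of the two transformations, the table can be rebuilt instead of typed in: compute the inverse in GF(2^8) by brute force (slow but simple; {00} maps to {00}), then apply the affine map, which for a whole byte is b ⊕ rotl(b, 1) ⊕ rotl(b, 2) ⊕ rotl(b, 3) ⊕ rotl(b, 4) ⊕ {63}. `gf_mul` and `build_sbox` are illustrative names. The spec's own example SBOX({53}) = {ed} and the round-1 SUBBYTES output from FIPS 197 Appendix B make convenient checks.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial {11b}."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = ((a << 1) ^ 0x11B) if a & 0x80 else (a << 1)
        b >>= 1
    return p


def build_sbox():
    """Inverse in GF(2^8), then the affine transformation with c = {63}."""
    def rotl(v, k):
        return ((v << k) | (v >> (8 - k))) & 0xFF

    table = []
    for x in range(256):
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        table.append(inv ^ rotl(inv, 1) ^ rotl(inv, 2) ^ rotl(inv, 3) ^ rotl(inv, 4) ^ 0x63)
    return tuple(table)


S_BOX = build_sbox()


def sub_bytes(state):
    return bytes(S_BOX[b] for b in state)


b = 0x53
row, col = b >> 4, b & 0x0F         # high nibble selects the row, low nibble the column
print(hex(S_BOX[16 * row + col]))   # 0xed

round1_start = bytes.fromhex("193de3bea0f4e22b9ac68d2ae9f84808")  # FIPS 197 Appendix B
print(sub_bytes(round1_start).hex())  # d42711aee0bf98f1b8b45de51e415230
```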
## Shift rows
SHIFTROWS cyclically shifts the last three rows of the state to the left: row r is rotated by r positions, so row 0 stays unchanged.
> SHIFTROWS() is a transformation of the state in which the bytes in the last three rows of the state are cyclically shifted. The number of positions by which the bytes are shifted depends on the row index r, as follows:
> $$s'_{r,c} = s_{r,(c+r) \bmod 4} \quad \text{for } 0 \le r < 4 \text{ and } 0 \le c < 4.$$
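The implementation uses `(j + i) % 4` to find which column each byte of row i comes from. Because the state is stored column by column, the byte at row r and column c sits at index 4 * c + r, which is why the code reads `state[((j + i) % 4) * 4 + i]`.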
```python
def shift_rows(state):
    state = list(state)
    new_state = state[:]
    for i in range(4):          # row i
        for j in range(4):      # column j
            new_state[j * 4 + i] = state[((j + i) % 4) * 4 + i]
    return bytes(new_state)
```
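To see the indexing at work, the sketch below applies `shift_rows` to the bytes 0..15 (each byte's value is its own index) and to the round-1 state of FIPS 197 Appendix B. The `rows` helper is hypothetical and only there to print the state row by row.

```python
def shift_rows(state):
    state = list(state)
    new_state = state[:]
    for i in range(4):
        for j in range(4):
            new_state[j * 4 + i] = state[((j + i) % 4) * 4 + i]
    return bytes(new_state)


def rows(state):
    """The four rows of a column-major 16-byte state, as hex strings."""
    return [state[r::4].hex(" ") for r in range(4)]


for row in rows(shift_rows(bytes(range(16)))):
    print(row)
# 00 04 08 0c
# 05 09 0d 01
# 0a 0e 02 06
# 0f 03 07 0b

after_sub_bytes = bytes.fromhex("d42711aee0bf98f1b8b45de51e415230")  # FIPS 197 Appendix B
print(shift_rows(after_sub_bytes).hex())  # d4bf5d30e0b452aeb84111f11e2798e5
```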
## Mix columns
The spec defines MIXCOLUMNS as follows:
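> MIXCOLUMNS() is a transformation of the state that multiplies each of the four columns of the state by a single fixed matrix, as described in Section 4.3, with its entries taken from the following word: [a0,a1,a2,a3] = [{02},{01},{01},{03}].
> $$
> \begin{aligned}
> s'_{0,c} &= (\{02\} \bullet s_{0,c}) \oplus (\{03\} \bullet s_{1,c}) \oplus s_{2,c} \oplus s_{3,c} \\
> s'_{1,c} &= s_{0,c} \oplus (\{02\} \bullet s_{1,c}) \oplus (\{03\} \bullet s_{2,c}) \oplus s_{3,c} \\
> s'_{2,c} &= s_{0,c} \oplus s_{1,c} \oplus (\{02\} \bullet s_{2,c}) \oplus (\{03\} \bullet s_{3,c}) \\
> s'_{3,c} &= (\{03\} \bullet s_{0,c}) \oplus s_{1,c} \oplus s_{2,c} \oplus (\{02\} \bullet s_{3,c})
> \end{aligned}
> $$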
It also defines XTIMES, which is multiplication by {02} (that is, by x) in GF(2^8):
> The modular reduction by m(x) may be applied to intermediate steps in the calculation of b(x)c(x); consequently, it is useful to consider the special case that c(x) = x (i.e., c = {02}). In particular, the product b • {02} can be expressed as a function of b, denoted by XTIMES(b), as follows:
> $$\mathrm{XTIMES}(b) = \begin{cases} \{b_6\,b_5\,b_4\,b_3\,b_2\,b_1\,b_0\,0\} & \text{if } b_7 = 0 \\ \{b_6\,b_5\,b_4\,b_3\,b_2\,b_1\,b_0\,0\} \oplus \{0\,0\,0\,1\,1\,0\,1\,1\} & \text{if } b_7 = 1 \end{cases}$$
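XTIMES is enough to multiply by any constant: write the constant as a sum of powers of {02} and XOR together the matching repeated-XTIMES results. A minimal sketch (`gf_mul` is an illustrative name, not from the spec), checked against the spec's worked examples {57} • {02} = {ae}, {57} • {13} = {fe} and {57} • {83} = {c1}:

```python
def xtimes(b):
    """b * {02} in GF(2^8): shift left, and XOR with {1b} if b7 was set."""
    return ((b << 1) ^ 0x1B) & 0xFF if b & 0x80 else b << 1


def gf_mul(a, b):
    """a * b in GF(2^8) by shift-and-add, using only XTIMES and XOR."""
    product = 0
    while b:
        if b & 1:           # this power of {02} is present in b
            product ^= a
        a = xtimes(a)       # next power: a * {02}^(k+1)
        b >>= 1
    return product


print(hex(xtimes(0x57)), hex(xtimes(0xAE)))                 # 0xae 0x47
print(hex(gf_mul(0x57, 0x13)), hex(gf_mul(0x57, 0x83)))     # 0xfe 0xc1
```

In particular `gf_mul(a, 3) == xtimes(a) ^ a` for every byte, which is the shortcut `mix_single_column` relies on.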
`mix_single_column` implements that matrix multiplication directly. Multiplication by {03} is split into {02} ⊕ {01}, so {03} • a becomes `xtimes(a) ^ a` and XTIMES is the only field operation needed.
```python
def mix_columns(state):
    def xtimes(a):
        # multiply by {02}; reduce with {1b} when the top bit overflows
        return (((a << 1) ^ 0x1b) & 0xff) if (a & 0x80) else (a << 1)

    def mix_single_column(a):
        # {03} * x is computed as xtimes(x) ^ x
        return [
            xtimes(a[0]) ^ xtimes(a[1]) ^ a[1] ^ a[2] ^ a[3],
            a[0] ^ xtimes(a[1]) ^ xtimes(a[2]) ^ a[2] ^ a[3],
            a[0] ^ a[1] ^ xtimes(a[2]) ^ xtimes(a[3]) ^ a[3],
            xtimes(a[0]) ^ a[0] ^ a[1] ^ a[2] ^ xtimes(a[3])
        ]

    state = list(state)
    for i in range(4):
        column = state[i*4:(i+1)*4]
        mixed_column = mix_single_column(column)
        state[i*4:(i+1)*4] = mixed_column
    return bytes(state)
```
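A check of `mix_columns` against round 1 of FIPS 197 Appendix B: the first column d4 bf 5d 30 of the state after SHIFTROWS should become 04 66 81 e5, and XORing the mixed state with the round-1 key w[4..7] = a0fafe17 88542cb1 23a33939 2a6c7605 gives the state that starts round 2. The function is repeated so the block runs on its own; the round key is XORed as one flat 16-byte string, which is equivalent to `add_round_key` because the words are the columns in order.

```python
def mix_columns(state):
    def xtimes(a):
        return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1)

    def mix_single_column(a):
        return [
            xtimes(a[0]) ^ xtimes(a[1]) ^ a[1] ^ a[2] ^ a[3],
            a[0] ^ xtimes(a[1]) ^ xtimes(a[2]) ^ a[2] ^ a[3],
            a[0] ^ a[1] ^ xtimes(a[2]) ^ xtimes(a[3]) ^ a[3],
            xtimes(a[0]) ^ a[0] ^ a[1] ^ a[2] ^ xtimes(a[3]),
        ]

    state = list(state)
    for i in range(4):
        state[4 * i:4 * i + 4] = mix_single_column(state[4 * i:4 * i + 4])
    return bytes(state)


after_shift_rows = bytes.fromhex("d4bf5d30e0b452aeb84111f11e2798e5")
mixed = mix_columns(after_shift_rows)
print(mixed.hex())  # 046681e5e0cb199a48f8d37a2806264c

round_key_1 = bytes.fromhex("a0fafe1788542cb123a339392a6c7605")  # w[4..7]
print(bytes(m ^ k for m, k in zip(mixed, round_key_1)).hex())  # a49c7ff2689f352b6b5bea43026a5049
```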
## Full implementation
Running the code below should print the following for the two examples (the first from an existing Python AES example, the gist linked in the code; the second from the spec):
```
my_aes = gfp6wzvTH3lN5TO2B37yWQ==
e : ['0x39', '0x25', '0x84', '0x1d', '0x2', '0xdc', '0x9', '0xfb', '0xdc', '0x11', '0x85', '0x97', '0x19', '0x6a', '0xb', '0x32']
my_aes = OSWEHQLcCfvcEYWXGWoLMg==
```
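The second base64 line is simply the FIPS 197 Appendix B ciphertext in another encoding; decoding it gives back the hex string the spec lists:

```python
import base64

e = base64.b64decode("OSWEHQLcCfvcEYWXGWoLMg==")
print(e.hex())  # 3925841d02dc09fbdc118597196a0b32
```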
```python
import base64
import logging

S_BOX = (
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
)

Nr = 10  # number of rounds (AES-128)
Nb = 4   # number of 32-bit columns in the state
Nk = 4   # number of 32-bit words in the key (AES-128)
BS = 16  # block size in bytes


def pad(s):
    # PKCS#7-style padding: append n bytes, each of value n
    return s + (BS - len(s) % BS) * chr(BS - len(s) % BS).encode()


def unpad(s):
    return s[:-ord(s[len(s)-1:])]


def key_expansion(key, Nk, Nb):
    Nr = Nk + 6  # 10, 12 or 14 rounds for AES-128/192/256

    def sub_word(word):
        return bytes(S_BOX[b] for b in word)

    def rot_word(word):
        return word[1:] + word[:1]

    Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
    w = [key[4 * i:4 * i + 4] for i in range(Nk)]
    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i - 1]
        if i % Nk == 0:
            temp = bytes(a ^ b for a, b in zip(sub_word(rot_word(temp)), [Rcon[i // Nk - 1], 0, 0, 0]))
        elif Nk > 6 and i % Nk == 4:
            temp = sub_word(temp)
        w.append(bytes(a ^ b for a, b in zip(w[i - Nk], temp)))
    return w


def add_round_key(state, round_key):
    new_state = bytearray(state)
    for i in range(4):
        for j in range(4):
            new_state[i * 4 + j] ^= round_key[i][j]
    return bytes(new_state)


def sub_bytes(state):
    return bytes(S_BOX[b] for b in state)


def shift_rows(state):
    state = list(state)
    new_state = state[:]
    for i in range(4):
        for j in range(4):
            new_state[j * 4 + i] = state[((j + i) % 4) * 4 + i]
    return bytes(new_state)


def mix_columns(state):
    def xtimes(a):
        return (((a << 1) ^ 0x1b) & 0xff) if (a & 0x80) else (a << 1)

    def mix_single_column(a):
        return [
            xtimes(a[0]) ^ xtimes(a[1]) ^ a[1] ^ a[2] ^ a[3],
            a[0] ^ xtimes(a[1]) ^ xtimes(a[2]) ^ a[2] ^ a[3],
            a[0] ^ a[1] ^ xtimes(a[2]) ^ xtimes(a[3]) ^ a[3],
            xtimes(a[0]) ^ a[0] ^ a[1] ^ a[2] ^ xtimes(a[3])
        ]

    state = list(state)
    for i in range(4):
        column = state[i*4:(i+1)*4]
        mixed_column = mix_single_column(column)
        state[i*4:(i+1)*4] = mixed_column
    return bytes(state)


def my_aes(message, key):
    # ECB mode: every 16-byte block is encrypted independently with the same key schedule
    w = key_expansion(key, Nk, Nb)
    ciphertext = b""
    for offset in range(0, len(message), BS):
        state = message[offset:offset + BS]
        logging.debug(f"Input state: {[hex(b) for b in state]}")
        state = add_round_key(state, w[0:4])
        for round in range(1, Nr):
            logging.debug(f"Round {round}: {[hex(b) for b in state]}")
            state = sub_bytes(state)
            logging.debug(f"sub_bytes {round}: {[hex(b) for b in state]}")
            state = shift_rows(state)
            logging.debug(f"shift_rows {round}: {[hex(b) for b in state]}")
            state = mix_columns(state)
            logging.debug(f"mix_columns {round}: {[hex(b) for b in state]}")
            state = add_round_key(state, w[4 * round:4 * round + 4])
            logging.debug(f"add_round_key {round}: {[hex(b) for b in state]}")
        # the final round has no MIXCOLUMNS
        state = sub_bytes(state)
        state = shift_rows(state)
        state = add_round_key(state, w[4 * Nr:4 * Nr + 4])
        logging.debug(f"Output state: {[hex(b) for b in state]}")
        ciphertext += state
    return ciphertext


# From https://gist.github.com/sachadee/dc06cc0b7747578e578a470ade0a5f2d
# should generate gfp6wzvTH3lN5TO2B37yWQ==
message = 'I love Medium'
key = 'AAAAAAAAAAAAAAAA'
message = pad(message.encode('utf-8'))
key = key.encode('utf-8')
e = my_aes(message, key)
print(f"my_aes = {base64.b64encode(e).decode('utf-8', 'ignore')}")

# Test vector from FIPS 197 Appendix B: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197-upd1.pdf
# should generate the following state at the end of the last round
# 39 02 dc 19
# 25 dc 11 6a
# 84 09 85 0b
# 1d fb 97 32
message = b'\x32\x43\xf6\xa8\x88\x5a\x30\x8d\x31\x31\x98\xa2\xe0\x37\x07\x34'
key = b'\x2b\x7e\x15\x16\x28\xae\xd2\xa6\xab\xf7\x15\x88\x09\xcf\x4f\x3c'
e = my_aes(message, key)
print(f"e : {[hex(b) for b in e]}")
print(f"my_aes = {base64.b64encode(e).decode('utf-8', 'ignore')}")
```
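The `pad`/`unpad` helpers follow PKCS#7: append n bytes, each of value n, so a message that is already a multiple of 16 bytes gets a whole extra block. That matters because `my_aes` only works on whole 16-byte blocks. The listing's `unpad` trusts its input; the stricter variant below is my own sketch, not part of the code above, and rejects malformed padding instead of silently cutting bytes off.

```python
BS = 16  # AES block size in bytes


def pad(data: bytes) -> bytes:
    """PKCS#7 padding: append n bytes of value n, with 1 <= n <= BS."""
    n = BS - len(data) % BS
    return data + bytes([n]) * n


def unpad(data: bytes) -> bytes:
    """Strip PKCS#7 padding, raising ValueError on malformed input."""
    if not data or len(data) % BS:
        raise ValueError("padded data must be a non-empty multiple of the block size")
    n = data[-1]
    if not 1 <= n <= BS or data[-n:] != bytes([n]) * n:
        raise ValueError("invalid padding")
    return data[:-n]


padded = pad(b"I love Medium")
print(len(padded), padded[-3:])  # 16 b'\x03\x03\x03'
print(len(pad(b"A" * 16)))       # 32
print(unpad(padded))             # b'I love Medium'
```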