#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* AES state: block[c][r] holds row r of column c, i.e. byte 4*c + r of the
 * 16-byte input/output block (column-major order). */
typedef uint8_t block_t[4][4];

/* Expanded-key context filled in by aes_set_key(). */
typedef struct aes_ctx {
    /* Round-key words, sized for the largest case (Nr = 14). */
    uint8_t e_key[4 * (14 + 1)][4];
    uint32_t Nk; /* key length in 32-bit words: 4, 6 or 8 */
    uint32_t Nr; /* number of rounds: 10, 12 or 14 */
} aes_ctx;

/* Multiply by x modulo x^8 + x^4 + x^3 + x + 1.
 * The reduction constant is selected with a mask derived from the top bit
 * rather than with a branch (the compiler is not strictly guaranteed to
 * keep this branch-free at the machine level). */
static uint8_t xtime(uint8_t b)
{
    const uint8_t mask = (uint8_t)-(b >> 7); /* 0xff if bit 7 is set, else 0x00 */
    return (uint8_t)((b << 1) ^ (0x1b & mask));
}

/* GF(2^8) generic multiplication, double-and-add.
 * Bit i of b decides whether the current multiple of a is accumulated; the
 * decision is made with a mask instead of an if, so control flow does not
 * depend on the operand values. */
static uint8_t multiply(uint8_t a, uint8_t b)
{
    uint8_t c = 0;
    for (size_t i = 0; i < 8; ++i) {
        const uint8_t mask = (uint8_t)-((b >> i) & 1); /* 0xff or 0x00 */
        c = (uint8_t)(c ^ (a & mask));
        a = xtime(a);
    }
    return c;
}

/* Rotate the byte x left by c bits (c is taken modulo 8). */
static uint8_t rotate_left(uint8_t x, size_t c)
{
    return (uint8_t)((x << c % 8) | (x >> (8 - c) % 8));
}

/* AES S-box computed on the fly: multiplicative inverse in GF(2^8),
 * obtained as x^254 (which maps 0 to 0), followed by the affine map
 * b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63. */
static uint8_t SubByte(const uint8_t x0)
{
    const uint8_t x1  = multiply( x0,  x0); // x^2
    const uint8_t x2  = multiply( x1,  x0); // x^3
    const uint8_t x3  = multiply( x2,  x2); // x^6
    const uint8_t x4  = multiply( x3,  x3); // x^12
    const uint8_t x5  = multiply( x4,  x2); // x^15
    const uint8_t x6  = multiply( x5,  x5); // x^30
    const uint8_t x7  = multiply( x6,  x6); // x^60
    const uint8_t x8  = multiply( x7,  x2); // x^63
    const uint8_t x9  = multiply( x8,  x8); // x^126
    const uint8_t x10 = multiply( x9,  x0); // x^127
    const uint8_t x11 = multiply(x10, x10); // x^254 = x^-1
    return (uint8_t)(x11 ^ rotate_left(x11, 1) ^ rotate_left(x11, 2) ^
                     rotate_left(x11, 3) ^ rotate_left(x11, 4) ^ 0x63);
}

/* Cyclically rotate a 4-byte word left by one byte. */
static void RotWord(uint8_t w[4])
{
    const uint8_t t = w[0];
    w[0] = w[1];
    w[1] = w[2];
    w[2] = w[3];
    w[3] = t;
}

/* Apply the S-box to each byte of a 4-byte word. */
static void SubWord(uint8_t w[4])
{
    for (size_t i = 0; i < 4; ++i)
        w[i] = SubByte(w[i]);
}

/* XOR four consecutive round-key words (one per column) into the state. */
static void AddRoundKey(block_t block, const uint8_t k[][4])
{
    for (size_t i = 0; i < 4; ++i)
        for (size_t j = 0; j < 4; ++j)
            block[i][j] ^= k[i][j];
}

/* Apply the S-box to every byte of the state. */
static void SubBytes(block_t block)
{
    for (size_t i = 0; i < 4; ++i)
        for (size_t j = 0; j < 4; ++j)
            block[i][j] = SubByte(block[i][j]);
}

/* Rotate row j of the state left by j positions. */
static void ShiftRows(block_t block)
{
    for (size_t j = 0; j < 4; ++j) {
        uint8_t row[4];
        for (size_t i = 0; i < 4; ++i)
            row[i] = block[i][j];
        for (size_t i = 0; i < 4; ++i)
            block[i][j] = row[(i + j) % 4];
    }
}

/* Multiply each column by the fixed polynomial {03}x^3 + {01}x^2 + {01}x + {02};
 * xtime(a ^ b) ^ b expands to 2a ^ 3b. */
static void MixColumns(block_t block)
{
    for (size_t i = 0; i < 4; ++i) {
        const uint8_t c0 = block[i][0];
        const uint8_t c1 = block[i][1];
        const uint8_t c2 = block[i][2];
        const uint8_t c3 = block[i][3];
        block[i][0] = (uint8_t)(xtime(c0 ^ c1) ^ c1 ^ c2 ^ c3);
        block[i][1] = (uint8_t)(c0 ^ xtime(c1 ^ c2) ^ c2 ^ c3);
        block[i][2] = (uint8_t)(c0 ^ c1 ^ xtime(c2 ^ c3) ^ c3);
        block[i][3] = (uint8_t)(xtime(c0 ^ c3) ^ c0 ^ c1 ^ c2);
    }
}

/* Expand a 128-, 192- or 256-bit key into ctx.
 * k must point to keybits / 8 bytes. Any other keybits value calls abort(). */
void aes_set_key(aes_ctx * ctx, uint8_t const * k, size_t keybits)
{
    switch (keybits) {
    case 128: ctx->Nr = 10; ctx->Nk = 4; break;
    case 192: ctx->Nr = 12; ctx->Nk = 6; break;
    case 256: ctx->Nr = 14; ctx->Nk = 8; break;
    default: abort();
    }

    uint8_t rcon = 1; /* round constant x^(i/Nk - 1) in GF(2^8) */

    /* The first Nk words are the key itself. */
    for (size_t i = 0; i < ctx->Nk; ++i)
        for (size_t j = 0; j < 4; ++j)
            ctx->e_key[i][j] = k[4 * i + j];

    for (size_t i = ctx->Nk; i < 4 * (ctx->Nr + 1); ++i) {
        uint8_t temp[4];
        for (size_t j = 0; j < 4; ++j)
            temp[j] = ctx->e_key[i - 1][j];
        if (i % ctx->Nk == 0) {
            RotWord(temp);
            SubWord(temp);
            temp[0] ^= rcon;
            rcon = xtime(rcon);
        } else if (ctx->Nk > 6 && i % ctx->Nk == 4) {
            /* Extra SubWord step used only for 256-bit keys (Nk = 8). */
            SubWord(temp);
        }
        for (size_t j = 0; j < 4; ++j)
            ctx->e_key[i][j] = ctx->e_key[i - ctx->Nk][j] ^ temp[j];
    }
}

/* Encrypt one 16-byte block with the expanded key in ctx.
 * input is copied into a local state before output is written, so output
 * may alias input (in-place encryption is allowed). */
void aes_encrypt(const aes_ctx * ctx, uint8_t * output, uint8_t const * input)
{
    block_t block;
    for (size_t i = 0; i < 4; ++i)
        for (size_t j = 0; j < 4; ++j)
            block[i][j] = input[i * 4 + j];

    AddRoundKey(block, &ctx->e_key[0]);

    for (size_t i = 1; i <= ctx->Nr - 1; ++i) {
        SubBytes(block);
        ShiftRows(block);
        MixColumns(block);
        AddRoundKey(block, &ctx->e_key[4 * i]);
    }

    SubBytes(block);
    ShiftRows(block);
    // No MixColumns in the final round
    AddRoundKey(block, &ctx->e_key[4 * ctx->Nr]);

    for (size_t i = 0; i < 4; ++i)
        for (size_t j = 0; j < 4; ++j)
            output[i * 4 + j] = block[i][j];
}

/* Encrypt the plaintext 00 11 22 .. ff under the key 00 01 02 .. 1f
 * (truncated to keybits / 8 bytes by aes_set_key). Print the computed
 * ciphertext, then expected_hex, one per line.
 * Returns 0 when they match, 1 otherwise. */
static int check_fips197_vector(size_t keybits, const char *expected_hex)
{
    uint8_t k[32];
    uint8_t b[16];
    for (size_t i = 0; i < 16; ++i)
        b[i] = (uint8_t)(i * 0x11);
    for (size_t i = 0; i < 32; ++i)
        k[i] = (uint8_t)i;

    aes_ctx ctx;
    aes_set_key(&ctx, k, keybits);
    aes_encrypt(&ctx, b, b);

    char got_hex[2 * 16 + 1];
    for (size_t i = 0; i < 16; ++i)
        snprintf(&got_hex[2 * i], 3, "%02x", (unsigned)b[i]);

    printf("%s\n", got_hex);
    printf("%s\n", expected_hex);
    return strcmp(got_hex, expected_hex) == 0 ? 0 : 1;
}

int main(void)
{
    // test vectors from FIPS-197, Appendix C.1 / C.2 / C.3
    int failures = 0;
    failures += check_fips197_vector(128, "69c4e0d86a7b0430d8cdb78070b4c55a");
    failures += check_fips197_vector(192, "dda97ca4864cdfe06eaf70a0ec0d7191");
    failures += check_fips197_vector(256, "8ea2b7ca516745bfeafc49904b496089");
    return failures == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}