From 315d4f675f412f56c21950eaade0fb1a12a45632 Mon Sep 17 00:00:00 2001 From: Aykut Yilmaz Date: Sat, 15 Jul 2023 00:20:16 +0200 Subject: [PATCH] fix(AES): fix bug causing wrong output sometimes --- dev_deps.ts | 26 ++-- src/aes/consts.ts | 61 ++++------ src/aes/mod.ts | 241 ++++++++++++++++---------------------- src/utils/bytes.ts | 12 ++ tests/aes.test.ts | 22 +++- tests/block-modes.test.ts | 66 +++++++++-- 6 files changed, 239 insertions(+), 189 deletions(-) create mode 100644 src/utils/bytes.ts diff --git a/dev_deps.ts b/dev_deps.ts index 08a3886..d09ca92 100644 --- a/dev_deps.ts +++ b/dev_deps.ts @@ -1,10 +1,22 @@ +export { parse as parseArgs } from "https://deno.land/std@0.194.0/flags/mod.ts"; +export { + assertEquals, + assertThrows +} from "https://deno.land/std@0.194.0/testing/asserts.ts"; export { bench, - runBenchmarks, + runBenchmarks } from "https://deno.land/std@0.92.0/testing/bench.ts"; -export { - assertEquals, - assertThrows, -} from "https://deno.land/std@0.92.0/testing/asserts.ts"; -export { parse as parseArgs } from "https://deno.land/std@0.92.0/flags/mod.ts"; -export { decodeString as decodeHex } from "https://deno.land/x/std@0.92.0/encoding/hex.ts"; +// export { +// decodeString as decodeHex, +// encodeToString as encodeHex, +// } from "https://deno.land/x/std@0.92.0/encoding/hex.ts"; +import { decode, encode } from "https://deno.land/std@0.194.0/encoding/hex.ts"; + +export function decodeHex(hex: string): Uint8Array { + return decode(new TextEncoder().encode(hex)); +} + +export function encodeHex(bytes: Uint8Array): string { + return new TextDecoder().decode(encode(bytes)); +} diff --git a/src/aes/consts.ts b/src/aes/consts.ts index 6b4fb4f..6cfea72 100644 --- a/src/aes/consts.ts +++ b/src/aes/consts.ts @@ -1,55 +1,46 @@ -export const S = new DataView(new ArrayBuffer(256)); -export const SI = new DataView(new ArrayBuffer(256)); - -export const T1 = new DataView(new ArrayBuffer(1024)); -export const T2 = new DataView(new ArrayBuffer(1024)); -export const T3 = new DataView(new ArrayBuffer(1024)); -export const T4 = new DataView(new ArrayBuffer(1024)); -export const T5 = new DataView(new ArrayBuffer(1024)); -export const T6 = new DataView(new ArrayBuffer(1024)); -export const T7 = new DataView(new ArrayBuffer(1024)); -export const T8 = new DataView(new ArrayBuffer(1024)); +const buffer = new ArrayBuffer(8704); + +export const S = new Uint8Array(buffer, 0, 256); +export const SI = new Uint8Array(buffer, 256, 256); +export const T1 = new Uint32Array(buffer, 512, 256); +export const T2 = new Uint32Array(buffer, 1536, 256); +export const T3 = new Uint32Array(buffer, 2560, 256); +export const T4 = new Uint32Array(buffer, 3584, 256); +export const T5 = new Uint32Array(buffer, 4608, 256); +export const T6 = new Uint32Array(buffer, 5632, 256); +export const T7 = new Uint32Array(buffer, 6656, 256); +export const T8 = new Uint32Array(buffer, 7680, 256); const d = new Uint8Array(256); const t = new Uint8Array(256); +let x2, x4, x8, s, tEnc, tDec, x = 0, xInv = 0; + for (let i = 0; i < 256; i++) { d[i] = i << 1 ^ (i >> 7) * 283; t[d[i] ^ i] = i; } -let x2, x4, x8, s, tEnc, tDec, xInv = 0; - -for (let x = 0; !S.getUint8(x); x ^= x2 || 1) { +for (; !S[x]; x ^= x2 || 1) { s = xInv ^ xInv << 1 ^ xInv << 2 ^ xInv << 3 ^ xInv << 4; s = s >> 8 ^ s & 255 ^ 99; - S.setUint8(x, s); - SI.setUint8(s, x); + S[x] = s; + SI[s] = x; x8 = d[x4 = d[x2 = d[x]]]; tDec = x8 * 0x1010101 ^ x4 * 0x10001 ^ x2 * 0x101 ^ x * 0x1010100; tEnc = d[s] * 0x101 ^ s * 0x1010100; - const i = x * 4; - - tEnc = tEnc << 24 ^ tEnc >>> 8; - T1.setUint32(i, tEnc); - tEnc = tEnc << 24 ^ tEnc >>> 8; - T2.setUint32(i, tEnc); - tEnc = tEnc << 24 ^ tEnc >>> 8; - T3.setUint32(i, tEnc); - tEnc = tEnc << 24 ^ tEnc >>> 8; - T4.setUint32(i, tEnc); - - tDec = tDec << 24 ^ tDec >>> 8; - T5.setUint32(s * 4, tDec); - tDec = tDec << 24 ^ tDec >>> 8; - T6.setUint32(s * 4, tDec); - tDec = tDec << 24 ^ tDec >>> 8; - T7.setUint32(s * 4, tDec); - tDec = tDec << 24 ^ tDec >>> 8; - T8.setUint32(s * 4, tDec); + T1[x] = tEnc = tEnc << 24 ^ tEnc >>> 8; + T2[x] = tEnc = tEnc << 24 ^ tEnc >>> 8; + T3[x] = tEnc = tEnc << 24 ^ tEnc >>> 8; + T4[x] = tEnc = tEnc << 24 ^ tEnc >>> 8; + + T5[s] = tDec = tDec << 24 ^ tDec >>> 8; + T6[s] = tDec = tDec << 24 ^ tDec >>> 8; + T7[s] = tDec = tDec << 24 ^ tDec >>> 8; + T8[s] = tDec = tDec << 24 ^ tDec >>> 8; xInv = t[xInv] || 1; } diff --git a/src/aes/mod.ts b/src/aes/mod.ts index 60e7602..39c4939 100644 --- a/src/aes/mod.ts +++ b/src/aes/mod.ts @@ -1,5 +1,6 @@ -import { S, SI, T1, T2, T3, T4, T5, T6, T7, T8 } from "./consts.ts"; import { BlockCipher } from "../block-modes/base.ts"; +import { bytesToWords } from "../utils/bytes.ts"; +import { S, SI, T1, T2, T3, T4, T5, T6, T7, T8 } from "./consts.ts"; /** * Advanced Encryption Standard (AES) block cipher. @@ -12,8 +13,8 @@ export class Aes implements BlockCipher { * The block size of the block cipher in bytes */ static readonly BLOCK_SIZE = 16; - #ke: DataView; - #kd: DataView; + #ke: Uint32Array; + #kd: Uint32Array; #nr: number; constructor(key: Uint8Array) { @@ -21,167 +22,131 @@ export class Aes implements BlockCipher { throw new Error("Invalid key size (must be either 16, 24 or 32 bytes)"); } - const keyView = new DataView(key.buffer, key.byteOffset, key.byteLength); const keyLen = key.length / 4; const rkc = key.length + 28; + const ke = new Uint32Array(rkc); + ke.set(bytesToWords(key), 0); + const kd = new Uint32Array(rkc); - this.#nr = (rkc - 4) * 4; - this.#ke = new DataView(new ArrayBuffer(rkc * 4)); - this.#kd = new DataView(new ArrayBuffer(rkc * 4)); - - for (let i = 0; i < key.length; i += 4) { - this.#ke.setUint32(i * 4, keyView.getUint32(i)); - } - - let rcon = 1; - for (let i = keyLen; i < rkc; i++) { - let tmp = this.#ke.getUint32((i - 1) * 4); + let i, j, tmp, rcon = 1; + for (i = keyLen; i < 4 * keyLen + 28; i++) { + tmp = ke[i - 1]; if (i % keyLen === 0 || (keyLen === 8 && i % keyLen === 4)) { - tmp = S.getUint8(tmp >>> 24) << 24 ^ - S.getUint8(tmp >> 16 & 0xff) << 16 ^ - S.getUint8(tmp >> 8 & 0xff) << 8 ^ - S.getUint8(tmp & 0xff); + tmp = S[tmp >>> 24] << 24 ^ S[tmp >> 16 & 255] << 16 ^ + S[tmp >> 8 & 255] << 8 ^ S[tmp & 255]; if (i % keyLen === 0) { tmp = tmp << 8 ^ tmp >>> 24 ^ rcon << 24; - rcon = rcon << 1 ^ (rcon >> 7) * 0x11b; + rcon = rcon << 1 ^ (rcon >> 7) * 283; } } - this.#ke.setUint32( - i * 4, - this.#ke.getUint32((i - keyLen) * 4) ^ tmp, - ); + ke[i] = ke[i - keyLen] ^ tmp; } - for (let j = 0, i = rkc; i; j++, i--) { - const tmp = this.#ke.getUint32(j & 3 ? i * 4 : (i - 4) * 4); + for (j = 0; i; j++, i--) { + tmp = ke[j & 3 ? i : i - 4]; if (i <= 4 || j < 4) { - this.#kd.setUint32(j * 4, tmp); + kd[j] = tmp; } else { - this.#kd.setUint32( - j * 4, - T5.getUint32(S.getUint8(tmp >>> 24) * 4) ^ - T6.getUint32(S.getUint8(tmp >> 16 & 0xff) * 4) ^ - T7.getUint32(S.getUint8(tmp >> 8 & 0xff) * 4) ^ - T8.getUint32(S.getUint8(tmp & 0xff) * 4), - ); + kd[j] = T5[S[tmp >>> 24]] ^ + T6[S[tmp >> 16 & 255]] ^ + T7[S[tmp >> 8 & 255]] ^ + T8[S[tmp & 255]]; } } + + this.#nr = ke.length / 4 - 2; + this.#ke = ke; + this.#kd = kd; } encryptBlock(data: DataView, offset: number) { - let t0 = data.getUint32(offset) ^ this.#ke.getUint32(0); - let t1 = data.getUint32(offset + 4) ^ this.#ke.getUint32(4); - let t2 = data.getUint32(offset + 8) ^ this.#ke.getUint32(8); - let t3 = data.getUint32(offset + 12) ^ this.#ke.getUint32(12); - let a0, a1, a2; - - for (let i = 16; i < this.#nr; i += 16) { - a0 = T1.getUint32((t0 >>> 24) * 4) ^ - T2.getUint32((t1 >> 16 & 0xff) * 4) ^ - T3.getUint32((t2 >> 8 & 0xff) * 4) ^ - T4.getUint32((t3 & 0xff) * 4) ^ - this.#ke.getUint32(i); - a1 = T1.getUint32((t1 >>> 24) * 4) ^ - T2.getUint32((t2 >> 16 & 0xff) * 4) ^ - T3.getUint32((t3 >> 8 & 0xff) * 4) ^ - T4.getUint32((t0 & 0xff) * 4) ^ - this.#ke.getUint32(i + 4); - a2 = T1.getUint32((t2 >>> 24) * 4) ^ - T2.getUint32((t3 >> 16 & 0xff) * 4) ^ - T3.getUint32((t0 >> 8 & 0xff) * 4) ^ - T4.getUint32((t1 & 0xff) * 4) ^ - this.#ke.getUint32(i + 8); - t3 = T1.getUint32((t3 >>> 24) * 4) ^ - T2.getUint32((t0 >> 16 & 0xff) * 4) ^ - T3.getUint32((t1 >> 8 & 0xff) * 4) ^ - T4.getUint32((t2 & 0xff) * 4) ^ - this.#ke.getUint32(i + 12); - t0 = a0, t1 = a1, t2 = a2; + const k = this.#ke; + let a = data.getUint32(offset + 0) ^ k[0], + b = data.getUint32(offset + 4) ^ k[1], + c = data.getUint32(offset + 8) ^ k[2], + d = data.getUint32(offset + 12) ^ k[3], + a2, + b2, + c2, + i, + ki = 4; + + for (i = 0; i < this.#nr; i++) { + a2 = T1[a >>> 24] ^ T2[b >> 16 & 255] ^ T3[c >> 8 & 255] ^ T4[d & 255] ^ + k[ki]; + b2 = T1[b >>> 24] ^ T2[c >> 16 & 255] ^ T3[d >> 8 & 255] ^ T4[a & 255] ^ + k[ki + 1]; + c2 = T1[c >>> 24] ^ T2[d >> 16 & 255] ^ T3[a >> 8 & 255] ^ T4[b & 255] ^ + k[ki + 2]; + d = T1[d >>> 24] ^ T2[a >> 16 & 255] ^ T3[b >> 8 & 255] ^ T4[c & 255] ^ + k[ki + 3]; + ki += 4; + a = a2; + b = b2; + c = c2; } - data.setUint32( - offset, - S.getUint8(t0 >>> 24) << 24 ^ S.getUint8(t1 >> 16 & 0xff) << 16 ^ - S.getUint8(t2 >> 8 & 0xff) << 8 ^ S.getUint8(t3 & 0xff) ^ - this.#ke.getUint32(this.#nr), - ); - data.setUint32( - offset + 4, - S.getUint8(t1 >>> 24) << 24 ^ S.getUint8(t2 >> 16 & 0xff) << 16 ^ - S.getUint8(t3 >> 8 & 0xff) << 8 ^ S.getUint8(t0 & 0xff) ^ - this.#ke.getUint32(this.#nr + 4), - ); - data.setUint32( - offset + 8, - S.getUint8(t2 >>> 24) << 24 ^ S.getUint8(t3 >> 16 & 0xff) << 16 ^ - S.getUint8(t0 >> 8 & 0xff) << 8 ^ S.getUint8(t1 & 0xff) ^ - this.#ke.getUint32(this.#nr + 8), - ); - data.setUint32( - offset + 12, - S.getUint8(t3 >>> 24) << 24 ^ S.getUint8(t0 >> 16 & 0xff) << 16 ^ - S.getUint8(t1 >> 8 & 0xff) << 8 ^ S.getUint8(t2 & 0xff) ^ - this.#ke.getUint32(this.#nr + 12), - ); + for (i = 0; i < 4; i++) { + data.setUint32( + offset + i * 4, + S[a >>> 24] << 24 ^ + S[b >> 16 & 255] << 16 ^ + S[c >> 8 & 255] << 8 ^ + S[d & 255] ^ + k[ki++], + ); + a2 = a; + a = b; + b = c; + c = d; + d = a2; + } } decryptBlock(data: DataView, offset: number) { - let t0 = data.getUint32(offset) ^ this.#kd.getUint32(0); - let t1 = data.getUint32(offset + 4) ^ this.#kd.getUint32(12); - let t2 = data.getUint32(offset + 8) ^ this.#kd.getUint32(8); - let t3 = data.getUint32(offset + 12) ^ this.#kd.getUint32(4); - let a0, a1, a2; - - for (let i = 16; i < this.#nr; i += 16) { - a0 = T5.getUint32((t0 >>> 24) * 4) ^ - T6.getUint32((t3 >> 16 & 0xff) * 4) ^ - T7.getUint32((t2 >> 8 & 0xff) * 4) ^ - T8.getUint32((t1 & 0xff) * 4) ^ - this.#kd.getUint32(i); - a1 = T5.getUint32((t1 >>> 24) * 4) ^ - T6.getUint32((t0 >> 16 & 0xff) * 4) ^ - T7.getUint32((t3 >> 8 & 0xff) * 4) ^ - T8.getUint32((t2 & 0xff) * 4) ^ - this.#kd.getUint32(i + 12); - a2 = T5.getUint32((t2 >>> 24) * 4) ^ - T6.getUint32((t1 >> 16 & 0xff) * 4) ^ - T7.getUint32((t0 >> 8 & 0xff) * 4) ^ - T8.getUint32((t3 & 0xff) * 4) ^ - this.#kd.getUint32(i + 8); - t3 = T5.getUint32((t3 >>> 24) * 4) ^ - T6.getUint32((t2 >> 16 & 0xff) * 4) ^ - T7.getUint32((t1 >> 8 & 0xff) * 4) ^ - T8.getUint32((t0 & 0xff) * 4) ^ - this.#kd.getUint32(i + 4); - t0 = a0, t1 = a1, t2 = a2; + const k = this.#kd; + let a = data.getUint32(offset + 0) ^ k[0], + b = data.getUint32(offset + 12) ^ k[1], + c = data.getUint32(offset + 8) ^ k[2], + d = data.getUint32(offset + 4) ^ k[3], + a2, + b2, + c2, + i, + ki = 4; + + for (i = 0; i < this.#nr; i++) { + a2 = T5[a >>> 24] ^ T6[b >> 16 & 255] ^ T7[c >> 8 & 255] ^ T8[d & 255] ^ + k[ki]; + b2 = T5[b >>> 24] ^ T6[c >> 16 & 255] ^ T7[d >> 8 & 255] ^ T8[a & 255] ^ + k[ki + 1]; + c2 = T5[c >>> 24] ^ T6[d >> 16 & 255] ^ T7[a >> 8 & 255] ^ T8[b & 255] ^ + k[ki + 2]; + d = T5[d >>> 24] ^ T6[a >> 16 & 255] ^ T7[b >> 8 & 255] ^ T8[c & 255] ^ + k[ki + 3]; + ki += 4; + a = a2; + b = b2; + c = c2; } - data.setUint32( - offset, - SI.getUint8(t0 >>> 24) << 24 ^ SI.getUint8(t3 >> 16 & 0xff) << 16 ^ - SI.getUint8(t2 >> 8 & 0xff) << 8 ^ SI.getUint8(t1 & 0xff) ^ - this.#kd.getUint32(this.#nr), - ); - data.setUint32( - offset + 4, - SI.getUint8(t1 >>> 24) << 24 ^ SI.getUint8(t0 >> 16 & 0xff) << 16 ^ - SI.getUint8(t3 >> 8 & 0xff) << 8 ^ SI.getUint8(t2 & 0xff) ^ - this.#kd.getUint32(this.#nr + 12), - ); - data.setUint32( - offset + 8, - SI.getUint8(t2 >>> 24) << 24 ^ SI.getUint8(t1 >> 16 & 0xff) << 16 ^ - SI.getUint8(t0 >> 8 & 0xff) << 8 ^ SI.getUint8(t3 & 0xff) ^ - this.#kd.getUint32(this.#nr + 8), - ); - data.setUint32( - offset + 12, - SI.getUint8(t3 >>> 24) << 24 ^ SI.getUint8(t2 >> 16 & 0xff) << 16 ^ - SI.getUint8(t1 >> 8 & 0xff) << 8 ^ SI.getUint8(t0 & 0xff) ^ - this.#kd.getUint32(this.#nr + 4), - ); + for (i = 0; i < 4; i++) { + data.setUint32( + offset + (3 & -i) * 4, + SI[a >>> 24] << 24 ^ + SI[b >> 16 & 255] << 16 ^ + SI[c >> 8 & 255] << 8 ^ + SI[d & 255] ^ + k[ki++], + ); + a2 = a; + a = b; + b = c; + c = d; + d = a2; + } } } diff --git a/src/utils/bytes.ts b/src/utils/bytes.ts new file mode 100644 index 0000000..d17d3d0 --- /dev/null +++ b/src/utils/bytes.ts @@ -0,0 +1,12 @@ +export function bytesToWords(bytes: Uint8Array): Uint32Array { + const dataView = new DataView( + bytes.buffer, + bytes.byteOffset, + bytes.byteLength, + ); + const words = new Uint32Array(bytes.length / 4); + for (let i = 0; i < words.length; i++) { + words[i] = dataView.getUint32(i * 4); + } + return words; +} \ No newline at end of file diff --git a/tests/aes.test.ts b/tests/aes.test.ts index 0ac5066..b120b5d 100644 --- a/tests/aes.test.ts +++ b/tests/aes.test.ts @@ -1,5 +1,5 @@ -import { assertEquals, assertThrows, decodeHex } from "../dev_deps.ts"; import { Aes } from "../aes.ts"; +import { assertEquals, assertThrows, decodeHex } from "../dev_deps.ts"; // https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/aes/AESAVS.pdf @@ -74,3 +74,23 @@ Deno.test("[Block Cipher] AES-256", () => { assertEquals(data, decodeHex(plaintext)); } }); + +Deno.test("[Block Cipher] AES-128 2.0", () => { + const plaintext = "00000000000000000000000000000000"; + const testVectors: readonly [string, string][] = [ + ["10a58869d74be5a374cf867cfb473859", "6d251e6944b051e04eaa6fb4dbf78465"], + ["caea65cdbb75e9169ecd22ebe6e54675", "6e29201190152df4ee058139def610bb"], + ["a2e2fa9baf7d20822ca9f0542f764a41", "c3b44b95d9d2f25670eee9a0de099fa3"], + ]; + + for (const [key, chiphertext] of testVectors) { + const aes = new Aes(decodeHex(key)); + const data = decodeHex(plaintext); + const dataView = new DataView(data.buffer); + aes.encryptBlock(dataView, 0); + assertEquals(data, decodeHex(chiphertext)); + + aes.decryptBlock(dataView, 0); + assertEquals(data, decodeHex(plaintext)); + } +}); \ No newline at end of file diff --git a/tests/block-modes.test.ts b/tests/block-modes.test.ts index de0fa17..3d502d2 100644 --- a/tests/block-modes.test.ts +++ b/tests/block-modes.test.ts @@ -1,6 +1,10 @@ -import { assertEquals, assertThrows } from "../dev_deps.ts"; -import { Cbc, Cfb, Ctr, Ecb, Ofb } from "../block-modes.ts"; import { Aes } from "../aes.ts"; +import { Cbc, Cfb, Ctr, Ecb, Ofb } from "../block-modes.ts"; +import { + assertEquals, + assertThrows, + decodeHex +} from "../dev_deps.ts"; const key = new Uint8Array(16); const iv = new Uint8Array(16); @@ -15,7 +19,7 @@ Deno.test("[Block Cipher Mode] Base", () => { cipher.decrypt(new Uint8Array(4)); }, Error, - "Invalid data size (must be multiple of 16 bytes)", + "Invalid data size (must be multiple of 16 bytes)" ); assertThrows( @@ -24,17 +28,63 @@ Deno.test("[Block Cipher Mode] Base", () => { cipher.encrypt(new Uint8Array(4)); }, Error, - "Invalid initialization vector size (must be 16 bytes)", + "Invalid initialization vector size (must be 16 bytes)" ); }); Deno.test("[Block Cipher Mode] ECB", () => { - const cipher = new Ecb(Aes, key); - const enc = cipher.encrypt(original); - const dec = cipher.decrypt(enc); - assertEquals(dec, original); + interface TestVector { + key: string; + plain: string; + cipher: string; + } + + const testVectors: TestVector[] = [ + { + key: "000102030405060708090a0b0c0d0e0f", + plain: "00112233445566778899aabbccddeeff", + cipher: "69c4e0d86a7b0430d8cdb78070b4c55a", + }, + { + key: "2b7e151628aed2a6abf7158809cf4f3c", + plain: "6bc1bee22e409f96e93d7e117393172a", + cipher: "3ad77bb40d7a3660a89ecaf32466ef97", + }, + { + key: "2b7e151628aed2a6abf7158809cf4f3c", + plain: "ae2d8a571e03ac9c9eb76fac45af8e51", + cipher: "f5d3d58503b9699de785895a96fdbaaf", + }, + { + key: "2b7e151628aed2a6abf7158809cf4f3c", + plain: "30c81c46a35ce411e5fbc1191a0a52ef", + cipher: "43b1cd7f598ece23881b00e3ed030688", + }, + { + key: "2b7e151628aed2a6abf7158809cf4f3c", + plain: "f69f2445df4f9b17ad2b417be66c3710", + cipher: "7b0c785e27e8ad3f8223207104725dd4", + }, + ]; + + for (const testVector of testVectors) { + const key = decodeHex(testVector.key); + const plain = decodeHex(testVector.plain); + const encrypted = decodeHex(testVector.cipher); + + const cipher = new Ecb(Aes, key); + const enc = cipher.encrypt(plain); + + assertEquals(plain, decodeHex(testVector.plain)); + assertEquals(enc, encrypted); + + const dec = cipher.decrypt(enc); + assertEquals(dec, plain); + } + }); + Deno.test("[Block Cipher Mode] CBC", () => { const cipher = new Cbc(Aes, key, iv); const decipher = new Cbc(Aes, key, iv);