-
Notifications
You must be signed in to change notification settings - Fork 0
/
aes_recover.py
130 lines (96 loc) · 3.62 KB
/
aes_recover.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import aes
import collections
import itertools
from utils import hex, compare, dump_diff
# Output byte positions (sorted) that differ when a single byte of state
# column ``idx`` is faulted before the final MixColumns: MixColumns spreads the
# fault over the whole column, and the final ShiftRows moves the byte of row r
# to column (idx - r) mod 4. Indices are column-major (index = 4 * col + row),
# which is consistent with the values below. Entries are lists so they compare
# equal to a list of differing indices.
FAULT_PATTERNS = [
    [0, 7, 10, 13],  # column 0
    [1, 4, 11, 14],  # column 1
    [2, 5, 8, 15],   # column 2
    [3, 6, 9, 12],   # column 3
]
# FAULT_DESTINATION[col][row] is the output byte index that state byte
# (row, col) lands on after the fault has been mixed: the final ShiftRows sends
# row r of column col to column (col - r) mod 4, i.e. column-major index
# 4 * ((col - r) % 4) + r. Each row is the unsorted (row-ordered) version of
# the matching FAULT_PATTERNS entry.
FAULT_DESTINATION = [
    [0, 13, 10, 7],   # column 0, rows 0..3
    [4, 1, 14, 11],   # column 1, rows 0..3
    [8, 5, 2, 15],    # column 2, rows 0..3
    [12, 9, 6, 3],    # column 3, rows 0..3
]
# Only the faulted row and the fault value determine the resulting difference;
# the faulted column does not, because MixColumns transforms each column
# independently with the same matrix. The result depends on nothing but
# aes.mix_single_column, so the table could be precomputed and inlined.
def compute_propagation():
    """Tabulate the column difference produced by every single-byte fault.

    Returns:
        A list of 1024 four-element lists, ordered row-major over
        ``(faulted_row, fault_value)`` with ``faulted_row`` in 0..3 and
        ``fault_value`` in 0..255, so entry ``row * 256 + value`` is the
        difference column after MixColumns. The ``value == 0`` entries are
        all-zero columns (no fault).
    """
    table = []
    for faulted_row, value in itertools.product(range(4), range(256)):
        column = [value if r == faulted_row else 0 for r in range(4)]
        # mix_single_column's return value is ignored: it mutates `column`.
        aes.mix_single_column(column)
        table.append(column)
    return table
# Each value in FAULT_PROPAGATION is the 4-byte column difference (after
# MixColumns) introduced by one (faulted_row, fault_value) pair; entry
# row * 256 + value corresponds to that pair, so there are 1024 entries. The
# four value == 0 entries are all-zero and can never match a real fault.
FAULT_PROPAGATION = compute_propagation()
class Fault:
    """A faulted output together with the state column its fault hit.

    Attributes:
        output: the faulted output, indexed by byte position like the
            reference output it is compared against.
        column: index (0..3) into FAULT_PATTERNS / FAULT_DESTINATION
            identifying which state column was faulted.
    """

    def __init__(self, output, column):
        self.output, self.column = output, column
def recognize_fault_pattern(diff):
    """Map the positions of differing output bytes to the faulted state column.

    Args:
        diff: collection of output byte indices that differ between the
            reference output and a faulted one.

    Returns:
        The column index ``idx`` such that ``FAULT_PATTERNS[idx]`` matches
        ``diff``, or ``None`` when ``diff`` is not the 4-byte signature of a
        single-column fault. Callers treat ``None`` as "unusable, skip it".
    """
    if len(diff) != 4:
        return None
    # FAULT_PATTERNS rows are sorted; compare order-insensitively so any
    # sequence or set of indices is accepted.
    positions = sorted(diff)
    for idx, pattern in enumerate(FAULT_PATTERNS):
        if positions == pattern:
            return idx
    # Four differing bytes that do not line up with one column (e.g. a
    # multi-byte fault) carry no usable information. Previously an `assert`
    # aborted the whole analysis here (or, under -O, `.index` raised
    # ValueError); report "unusable" instead.
    return None
def compute_key_candidates(ref_b, out_b, d):
    """List the key byte guesses consistent with an observed difference.

    A guess ``k`` is kept when undoing the XOR with ``k`` and then the S-box
    (via ``aes.inv_s_box``) on both the reference byte ``ref_b`` and the
    faulted byte ``out_b`` leaves exactly the difference ``d``.

    Returns:
        The matching guesses in ascending order (possibly empty).
    """
    inv = aes.inv_s_box
    return [
        guess
        for guess in range(256)
        if inv[ref_b ^ guess] ^ inv[out_b ^ guess] == d
    ]
def _key_candidate_table(ref_b, out_b):
    """Bucket every key byte guess by the difference it yields for one byte pair.

    ``table.get(d)`` is the ascending list of guesses ``k`` with
    ``aes.inv_s_box[ref_b ^ k] ^ aes.inv_s_box[out_b ^ k] == d`` (the same
    list ``compute_key_candidates(ref_b, out_b, d)`` returns), or ``None``
    when there is none. Building it costs one 256-step pass instead of one
    pass per difference tried.
    """
    inv_s_box = aes.inv_s_box
    table = collections.defaultdict(list)
    for k in range(256):  # ascending, so every bucket stays sorted
        table[inv_s_box[ref_b ^ k] ^ inv_s_box[out_b ^ k]].append(k)
    return dict(table)


def recover_key(ref, faults):
    """Vote on the 16 key bytes using a set of classified faults.

    Args:
        ref: the fault-free output, indexable by byte position.
        faults: iterable of ``Fault`` objects (faulted output + column).

    Returns:
        A list of 16 ints. Assuming ``ref`` is the correct ciphertext, these
        presumably are the bytes of the final round key (the S-box is peeled
        off right after the XOR), not the cipher key itself.

    Raises:
        ValueError: if some state column received no usable fault, so its
            four key bytes cannot be determined.
    """
    # One vote counter per state column; keys are 4-tuples of key bytes
    # ordered like FAULT_DESTINATION[column].
    key_sets = [collections.Counter() for _ in range(4)]
    for fault in faults:
        print(f"Analyzing {hex(fault.output)}")
        key_indices = FAULT_DESTINATION[fault.column]
        tables = [
            _key_candidate_table(ref[idx], fault.output[idx])
            for idx in key_indices
        ]
        # Neither the faulted row nor the fault value is known, so try every
        # possible post-MixColumns difference column.
        for delta in FAULT_PROPAGATION:
            candidates = []
            for table, d in zip(tables, delta):
                guesses = table.get(d)
                if not guesses:
                    break  # delta impossible for this byte: reject it
                candidates.append(guesses)
            else:
                # Every byte admits a guess: vote once for each combination.
                key_sets[fault.column].update(itertools.product(*candidates))
    # After all faults, the most frequent tuple per column is very likely the
    # correct key material.
    key = [0] * 16
    for column, destination in enumerate(FAULT_DESTINATION):
        votes = key_sets[column].most_common(1)
        if not votes:
            raise ValueError(
                f"no usable fault for state column {column}; cannot recover "
                f"key bytes at positions {destination}"
            )
        best = votes[0][0]
        for key_index, key_byte in zip(destination, best):
            key[key_index] = key_byte
    return key
def recover_aes_key(ref, faulted_outputs):
    """Classify faulted outputs against ``ref`` and recover the key bytes.

    Each faulted output is compared with ``ref``. Outputs for which
    ``recognize_fault_pattern`` returns ``None`` are skipped. The rest have
    their diff dumped via ``dump_diff`` and are wrapped in ``Fault`` objects,
    which are then passed to ``recover_key``; its result is returned as-is.
    """
    usable = []
    for output in faulted_outputs:
        positions = compare(ref, output)
        column = recognize_fault_pattern(positions)
        if column is not None:
            dump_diff(ref, output, positions)
            usable.append(Fault(output, column))
    return recover_key(ref, usable)