Tampering Detection System
Co-authors: @JOHNKRAM @Rosayxy
Attachment:
#!/usr/local/bin/python
import base64
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.strxor import strxor
try:
with open("/flag.txt", "rb") as f:
FLAG = f.read()
except FileNotFoundError:
FLAG = b"FLAG{******** REDACTED ********}"
def encrypt(plaintext: bytes, key: bytes, nonce: bytes, associated_data: bytes = b""):
cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
cipher.update(associated_data)
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
return ciphertext, tag
def decrypt(ciphertext: bytes, key: bytes, nonce: bytes, tag: bytes, associated_data: bytes = b""):
cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
cipher.update(associated_data)
plaintext = cipher.decrypt_and_verify(ciphertext, tag)
return plaintext
def query(ciphertext: bytes, aad: bytes, key: bytes, tag: bytes, nonce: bytes):
cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
try:
cipher.update(aad)
cipher.decrypt_and_verify(ciphertext, tag)
return True
except:
return False
nonce = get_random_bytes(12)
key = get_random_bytes(32)
flag_ciphertext, flag_tag = encrypt(FLAG, key, nonce, b"")
print("flag ciphertext: ", base64.b64encode(flag_ciphertext).decode())
print("flag tag: ", base64.b64encode(flag_tag).decode())
user_plaintext = input("your_text1:")
ciphertext, tag = encrypt(user_plaintext.encode(), key, nonce, b"")
print("tag1: ", base64.b64encode(tag).decode())
user_plaintext = input("your_text2:")
ciphertext, tag = encrypt(user_plaintext.encode(), key, nonce, b"")
print("tag2: ", base64.b64encode(tag).decode())
while True:
length = int(input("length:"))
aad = base64.b64decode(input("aad: "))
print(query(ciphertext[:length], aad, key, tag, nonce))
The server reuses nonce for AES-GCM mode, which is vulnerable. We deployed the attack according to AES-GCM and breaking it on nonce reuse:
- We obtain two pairs of (plain, tag): (plain1, tag1) and (plain2, tag2), and because they are encrypted using the same key and nonce, so the cipher texts satisfy:
cipher1 xor cipher2 == plain1 xor plain2when they have the same length (pointed out by @JOHNKRAM). - Given the two known pairs, we can create a polynomial of H by xor-ing the computation of tag1 and tag2. We can find its roots using SageMath. See the article above for a detailed description.
- After solving H, we conduct an attack similar to CBC Padding Oracle Attack, but from the front: from length of one to flag full length, enumerate each byte and compute the authentication data that computes to the same tag. If it replys with True, we found the correct byte.
Attack script:
from pwn import *
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Crypto.Util.strxor import strxor
from Crypto.Util.Padding import pad
from Crypto.Util.strxor import strxor
import base64
import galois
GF = galois.GF(2**128, irreducible_poly="x^128 + x^7 + x^2 + x + 1")
def to_gf(num):
# from LSB to MSB
ret = GF(0)
assert num <= 2**128
for i in range(128):
if num & (1 << i) != 0:
ret += GF(2 ** (127 - i))
return ret
def from_gf(gf):
# from LSB to MSB
num = int(gf)
ret = 0
for i in range(128):
if num & (1 << i) != 0:
ret += 2 ** (127 - i)
return ret
# context(log_level="DEBUG")
def gcm_nonce_reuse(cipher1: bytes, tag1: bytes, cipher2: bytes, tag2: bytes) -> bytes:
"""
AES GCM Nonce Reuse attack, learned from https://frereit.de/aes_gcm/,
given two pairs of known (cipher, tag) encrypted using the same key and associated data,
recover H = AES-ECB-Encrypt(key, b"\x00" * 16),
because cipher1 xor cipher2 == plain1 xor plain2, plaintext can be used instead
Args:
cipher1 (bytes): the first ciphertext
tag1 (bytes): the first tag
cipher2 (bytes): the second ciphertext
tag2 (bytes): the second tag
Returns:
H: the array of recovered H
"""
from sage.all import GF, PolynomialRing
# use sage to solve polynomials on GF(2^128)
F = GF(2)["a"]
(a,) = F._first_ngens(1)
F = GF(2**128, modulus=a**128 + a**7 + a**2 + a + 1, names=("x",))
(x,) = F._first_ngens(1)
R = PolynomialRing(F, names=("H",))
(H,) = R._first_ngens(1)
# construct polynomial based on two tuples
polys = []
for cipher, tag in [(cipher1, tag1), (cipher2, tag2)]:
poly = 0
blocks = []
# plaintext part
for i in range((len(cipher) + 15) // 16):
part = cipher[i * 16 : (i + 1) * 16]
if len(part) < 16:
part += b"\x00" * (16 - len(part))
blocks.append(part)
# len(plain) part
blocks.append(long_to_bytes(len(cipher) * 8, 16))
# compute poly for blocks
for block in blocks:
temp = 0
for i in range(128):
if bytes_to_long(block) & (1 << i) != 0:
temp += x ** (127 - i)
poly = poly * H + temp
# compute poly for tag part
temp = 0
for i in range(128):
if bytes_to_long(tag) & (1 << i) != 0:
temp += x ** (127 - i)
poly = poly * H + temp
polys.append(poly)
roots = (polys[0] + polys[1]).roots()
res = []
for root in roots:
coefs = root[0].list()
H = 0
for i in range(128):
if coefs[i] != 0:
H += 2 ** (127 - i)
res.append(long_to_bytes(H))
return res
p = process(["python3", "./server.py"])
# step 1: recover H
p.recvuntil(b"flag ciphertext: ")
flag_cipher = base64.b64decode(p.recvline().decode())
p.recvuntil(b"flag tag: ")
flag_tag = base64.b64decode(p.recvline().decode())
plain_len = 128
plain1 = b"A" * plain_len
p.recvuntil(b"your_text1:")
p.sendline(plain1)
p.recvuntil(b"tag1: ")
tag1 = base64.b64decode(p.recvline().decode())
plain2 = b"B" * plain_len
p.recvuntil(b"your_text2:")
p.sendline(plain2)
p.recvuntil(b"tag2: ")
tag2 = base64.b64decode(p.recvline().decode())
# plain1 xor plain2 == cipher1 xor cipher2
res = gcm_nonce_reuse(plain1, tag1, plain2, tag2)
for H in res:
H_gf = to_gf(bytes_to_long(H))
H_gf_inverse = H_gf**-1
# compute J0_enc
T = 0
for i in range((len(flag_cipher) + 15) // 16):
part = flag_cipher[i * 16 : (i + 1) * 16]
# update
T = T * H_gf + to_gf(bytes_to_long(part) * (256 ** (16 - len(part))))
L = len(flag_cipher) * 8
T = T * H_gf + to_gf(L)
T = T * H_gf + to_gf(bytes_to_long(flag_tag))
J0_enc = T
# enumerate cipher2
cipher2 = bytearray([0] * plain_len)
for length in range(1, 64):
for b in range(256):
cipher2[length-1] = b
cur_cipher2 = cipher2[:length]
T = 0
blocks = (len(cur_cipher2) + 15) // 16
for i in range(blocks):
part = cur_cipher2[i * 16 : (i + 1) * 16]
# update
T = T * H_gf + to_gf(bytes_to_long(part) * (256 ** (16 - len(part))))
L = len(cur_cipher2) * 8 + 16 * 8 * (2 ** 64)
T = T * H_gf + to_gf(L)
T = T * H_gf + J0_enc + to_gf(bytes_to_long(tag2))
# found required aad to match tag
aad = T * H_gf_inverse ** (2+blocks)
# correct guess?
p.recvuntil(b"length:")
p.sendline(str(length).encode())
p.recvuntil(b"aad: ")
aad_data = long_to_bytes(from_gf(aad), 16)
p.sendline(base64.b64encode(aad_data))
resp = p.recvline()
if resp == b"True\n":
print("Found", b)
# recover plain using the current cipher
l = min(len(flag_cipher), len(plain2))
temp = strxor(bytes(cipher2[:l]), flag_cipher[:l])
temp = strxor(temp, plain2[:l])
print("flag", temp)
break
elif resp != b"False\n":
assert False, resp
Alternative solution
An alternative solution attacks from the back to the front, which is more complex but does not require flag_tag:
- After solving H, we conduct an attack similar to CBC Padding Oracle Attack from the back: shorten the length one by one, enumerate the truncated byte, and compute the authentication data that computes to the same tag. By changing the length of the plaintext, we can recover different parts of the flag (pointed out by @JOHNKRAM).
- We can recover most of the flag now (except byte at 0, 16, 32). The rest can be simply bruteforced, due to the small space. Given the flag, we can recover
AES-ECB-Encrypt(key, nonce || 1)and verify if tag1 can be computed correctly.
Attack script (modified to match server's flag length of 42 bytes):
from pwn import *
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Crypto.Util.strxor import strxor
from Crypto.Util.Padding import pad
from Crypto.Util.strxor import strxor
import base64
import galois
GF = galois.GF(2**128, irreducible_poly="x^128 + x^7 + x^2 + x + 1")
def to_gf(num):
# from LSB to MSB
ret = GF(0)
assert num <= 2**128
for i in range(128):
if num & (1 << i) != 0:
ret += GF(2 ** (127 - i))
return ret
def from_gf(gf):
# from LSB to MSB
num = int(gf)
ret = 0
for i in range(128):
if num & (1 << i) != 0:
ret += 2 ** (127 - i)
return ret
# context(log_level="DEBUG")
def gcm_nonce_reuse(cipher1: bytes, tag1: bytes, cipher2: bytes, tag2: bytes) -> bytes:
"""
AES GCM Nonce Reuse attack, learned from https://frereit.de/aes_gcm/,
given two pairs of known (cipher, tag) encrypted using the same key and associated data,
recover H = AES-ECB-Encrypt(key, b"\x00" * 16),
because cipher1 xor cipher2 == plain1 xor plain2, plaintext can be used instead
Args:
cipher1 (bytes): the first ciphertext
tag1 (bytes): the first tag
cipher2 (bytes): the second ciphertext
tag2 (bytes): the second tag
Returns:
H: the array of recovered H
"""
from sage.all import GF, PolynomialRing
# use sage to solve polynomials on GF(2^128)
F = GF(2)["a"]
(a,) = F._first_ngens(1)
F = GF(2**128, modulus=a**128 + a**7 + a**2 + a + 1, names=("x",))
(x,) = F._first_ngens(1)
R = PolynomialRing(F, names=("H",))
(H,) = R._first_ngens(1)
# construct polynomial based on two tuples
polys = []
for cipher, tag in [(cipher1, tag1), (cipher2, tag2)]:
poly = 0
blocks = []
# plaintext part
for i in range((len(cipher) + 15) // 16):
part = cipher[i * 16 : (i + 1) * 16]
if len(part) < 16:
part += b"\x00" * (16 - len(part))
blocks.append(part)
# len(plain) part
blocks.append(long_to_bytes(len(cipher) * 8, 16))
# compute poly for blocks
for block in blocks:
temp = 0
for i in range(128):
if bytes_to_long(block) & (1 << i) != 0:
temp += x ** (127 - i)
poly = poly * H + temp
# compute poly for tag part
temp = 0
for i in range(128):
if bytes_to_long(tag) & (1 << i) != 0:
temp += x ** (127 - i)
poly = poly * H + temp
polys.append(poly)
roots = (polys[0] + polys[1]).roots()
res = []
for root in roots:
coefs = root[0].list()
H = 0
for i in range(128):
if coefs[i] != 0:
H += 2 ** (127 - i)
res.append(long_to_bytes(H))
return res
flag_parts = []
for plain_len in [16, 32, 48]:
p = process(["python3", "./server.py"])
# step 1: recover H
p.recvuntil(b"flag ciphertext: ")
flag_cipher = base64.b64decode(p.recvline().decode())
p.recvuntil(b"flag tag: ")
flag_tag = base64.b64decode(p.recvline().decode())
plain1 = b"A" * plain_len
p.recvuntil(b"your_text1:")
p.sendline(plain1)
p.recvuntil(b"tag1: ")
tag1 = base64.b64decode(p.recvline().decode())
plain2 = b"B" * plain_len
p.recvuntil(b"your_text2:")
p.sendline(plain2)
p.recvuntil(b"tag2: ")
tag2 = base64.b64decode(p.recvline().decode())
# plain1 xor plain2 == cipher1 xor cipher2
res = gcm_nonce_reuse(plain1, tag1, plain2, tag2)
print("Got", len(res), "H")
for H in res:
print(H)
H_gf = to_gf(bytes_to_long(H))
H_gf_inverse = H_gf**-1
# for all possible ciphertext, compute corresponding AAD
cipher = [0] * len(plain2)
# from right
for i in range(16):
# enumerate the next byte
good = False
for b in range(256):
length = len(plain2) - i - 1
cipher[length] = b
# compute contribution of the truncated bytes
part = cipher[length - 1 :]
contribution_b = to_gf(bytes_to_long(bytes(part))) * H_gf**2
# compute contribution of len(CT)
contribution_len_ct = (
to_gf(len(plain2) * 8) + to_gf(length * 8)
) * H_gf
# compute contribution of len(AAD)
contribution_len_aad = to_gf(128 * (2**64)) * H_gf
blocks = len(plain2) // 16
aad = (
contribution_b + contribution_len_ct + contribution_len_aad
) * H_gf_inverse ** (2 + blocks)
# cancel out
contribution = (
contribution_b
+ contribution_len_ct
+ contribution_len_aad
+ aad * (H_gf) ** (2 + blocks)
)
assert contribution == 0
# correct guess?
p.recvuntil(b"length:")
p.sendline(str(length).encode())
p.recvuntil(b"aad: ")
aad_data = long_to_bytes(from_gf(aad), 16)
p.sendline(base64.b64encode(aad_data))
resp = p.recvline()
if resp == b"True\n":
print("Found", b)
good = True
break
elif resp != b"False\n":
assert False, resp
if not good:
print("Bad H")
break
cipher[length] = b
print("cipher", bytes(cipher))
# recover plain using the current cipher
l = min(len(flag_cipher), len(plain2))
temp = strxor(bytes(cipher[:l]), flag_cipher[:l])
temp = strxor(temp, plain2[:l])
print("flag", temp)
if i >= 15:
flag_parts.append(temp)
print("Done", i)
break
flag = flag_parts[0] + flag_parts[1][16:] + flag_parts[2][32:]
print(flag)
# find flag[0, 16, 32]
p = process(["python3", "./server.py"])
# step 1: recover H
p.recvuntil(b"flag ciphertext: ")
flag_cipher = base64.b64decode(p.recvline().decode())
p.recvuntil(b"flag tag: ")
flag_tag = base64.b64decode(p.recvline().decode())
plain_len = len(flag)
plain1 = b"A" * plain_len
p.recvuntil(b"your_text1:")
p.sendline(plain1)
p.recvuntil(b"tag1: ")
tag1 = base64.b64decode(p.recvline().decode())
plain2 = b"B" * plain_len
p.recvuntil(b"your_text2:")
p.sendline(plain2)
p.recvuntil(b"tag2: ")
tag2 = base64.b64decode(p.recvline().decode())
# plain1 xor plain2 == cipher1 xor cipher2
res = gcm_nonce_reuse(plain1, tag1, plain2, tag2)
for H in res:
H_gf = to_gf(bytes_to_long(H))
# compute J0_enc
T = 0
for i in range((len(flag) + 15) // 16):
part = flag_cipher[i * 16 : (i + 1) * 16]
# update
T = T * H_gf + to_gf(bytes_to_long(part) * (256 ** (16 - len(part))))
L = len(flag) * 8
T = T * H_gf + to_gf(L)
T = T * H_gf + to_gf(bytes_to_long(flag_tag))
J0_enc = T
for ch1 in string.printable:
for ch2 in string.printable:
# guess flag
cur_flag = bytearray(flag)
cur_flag[0] = ord("F") # assume flag starting with F
cur_flag[16] = ord(ch1)
cur_flag[32] = ord(ch2)
# compute cipher2
temp = strxor(flag_cipher, cur_flag)
cipher2 = strxor(temp, plain2)
# verify tag2
T = 0
for i in range((len(plain2) + 15) // 16):
part = cipher2[i * 16 : (i + 1) * 16]
# update
T = T * H_gf + to_gf(bytes_to_long(part) * (256 ** (16 - len(part))))
L = len(plain2) * 8
T = T * H_gf + to_gf(L)
T = T * H_gf + J0_enc
if T == from_gf(bytes_to_long(tag2)):
print("Found flag", bytes(cur_flag))
exit(0)
Good to read: AES GCM and AES GCM-SIV mode.