ctf-writeups

leaky-rsa

by wjaaaaaaat
Description

A few of my bits are leaky... but they're not important. That's why they're called the least significant!

nc leaky-rsa.chal.imaginaryctf.org 1337

Attachment:

```python
#!/usr/local/bin/python3
import json
from Crypto.Util.number import getPrime
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from secrets import randbelow, token_bytes
from hashlib import sha256

with open('flag.txt') as f:
    flag = f.read()

p = getPrime(512)
q = getPrime(512)
n = p * q
e = 65537
d = pow(e, -1, (p-1)*(q-1))

key_m = randbelow(n)
key_c = pow(key_m, e, n)

key = sha256(str(key_m).encode()).digest()[:16]
iv = token_bytes(16)
ct = AES.new(key, AES.MODE_CBC, IV=iv).encrypt(pad(flag.encode(), 16))

print(json.dumps({'n': n, 'c': key_c, 'iv': iv.hex(), 'ct': ct.hex()}))

def get_bit(n, k):
    return (n >> k) % 2

for _ in range(1024):
    idx = randbelow(4)
    print(json.dumps({'idx': idx}))
    try:
        response = json.loads(input())
        c = response['c'] % n
        assert c != key_c
        m = pow(c, d, n)
        b = get_bit(m, idx)
    except (json.JSONDecodeError, TypeError, KeyError, ValueError, AssertionError):
        b = 2
    print(json.dumps({'b': b}))
print(key_m)
```

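Initially, I thought it was an RSA parity (LSB) oracle challenge: we can send c*pow(2, i*e, n)%n to the server and learn whether 2**i*m%n is even or odd. Each answer halves the interval containing m, so a binary search recovers m. However, that needs about 1024 such queries for a 1024-bit n, and the server only reveals bit 0 when idx = randbelow(4) happens to be 0, i.e. in roughly a quarter of the 1024 rounds, so we get only about 256 useful queries.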
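The parity-oracle idea can be checked offline. This is a minimal local simulation, not the challenge solve: it assumes toy primes P and Q instead of the server's 512-bit ones, a parity_oracle function that plays the part of a round with idx = 0, and a hypothetical helper named lsb_oracle_attack. Each parity answer turns out to be the next binary digit of m/n, so after n.bit_length() queries only one integer m fits.

```python
from secrets import randbelow

# Toy RSA key (assumption: small fixed primes instead of the challenge's 512-bit ones)
P, Q = 1_000_000_007, 998_244_353
N = P * Q
E = 65537
D = pow(E, -1, (P - 1) * (Q - 1))


def parity_oracle(c: int) -> int:
    """Local stand-in for a server round with idx == 0: the LSB of c^d mod N."""
    return pow(c, D, N) & 1


def lsb_oracle_attack(c: int, e: int, n: int, oracle) -> int:
    """Recover m from c = m^e mod n using n.bit_length() parity queries."""
    k = n.bit_length()          # 2**k > n, so k binary digits of m/n pin m down
    mult = pow(2, e, n)         # RSA encryption of 2
    x = 0                       # after i steps: x == floor(2**i * m / n)
    cc = c
    for _ in range(k):
        cc = cc * mult % n      # now an encryption of 2**i * m mod n
        # 2**i * m is even and n is odd, so the parity of 2**i * m mod n
        # equals the parity of floor(2**i * m / n): the next binary digit of m/n
        x = (x << 1) | oracle(cc)
    # m lies in [n*x / 2**k, n*(x+1) / 2**k), an interval shorter than 1,
    # so it is the ceiling of the left endpoint
    return (n * x + (1 << k) - 1) >> k


if __name__ == "__main__":
    m = randbelow(N)
    recovered = lsb_oracle_attack(pow(m, E, N), E, N, parity_oracle)
    print(recovered == m, "after", N.bit_length(), "queries")
```

For the challenge's 1024-bit n this loop would need 1024 parity answers, all of them from rounds with idx = 0.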

Then DeepSeek spotted the print(key_m) on the last line. WTF, the key is just given to us. We don't need any of those attacks: answer all 1024 rounds with garbage and wait for the key, then decrypt.

```python
import json
from pwn import *
from Crypto.Cipher import AES
from hashlib import sha256
import tqdm

# context(log_level="debug")

p = remote(host="leaky-rsa.chal.imaginaryctf.org", port=1337)
# p = process(["python3", "chall.py"])
# skip proof-of-work line
print(p.recvline())
info = json.loads(p.recvline())
for _ in tqdm.tqdm(range(1024)):
    p.recvline()    # {"idx": ...}
    p.sendline()    # an empty line is not valid JSON, so the server answers b = 2
    p.recvline()    # {"b": 2}
key_m = int(p.recvline())
# print(info, key_m)

key = sha256(str(key_m).encode()).digest()[:16]
iv = bytes.fromhex(info['iv'])
ct = bytes.fromhex(info['ct'])
flag = AES.new(key, AES.MODE_CBC, IV=iv).decrypt(ct)
print(flag, key, iv, ct)
```
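Why does an empty p.sendline() get us through the loop? A minimal sketch, assuming toy RSA parameters and a hypothetical helper server_round that copies one iteration of the challenge's query loop: any reply that fails to parse, lacks a 'c' key, or equals key_c lands in the except branch, so the server answers b = 2 and moves on to the next round instead of disconnecting. After 1024 rounds it therefore always reaches print(key_m).

```python
import json

# Toy RSA key (assumption, for illustration only)
P, Q = 1_000_000_007, 998_244_353
N = P * Q
E = 65537
D = pow(E, -1, (P - 1) * (Q - 1))
KEY_C = pow(123456789, E, N)  # stands in for the server's key_c


def server_round(line: str, idx: int) -> int:
    """Hypothetical copy of one iteration of the challenge's query loop."""
    try:
        response = json.loads(line)
        c = response['c'] % N
        assert c != KEY_C
        m = pow(c, D, N)
        b = (m >> idx) % 2
    except (json.JSONDecodeError, TypeError, KeyError, ValueError, AssertionError):
        b = 2
    return b


if __name__ == "__main__":
    # Each of these replies lands in the except branch, so every line shows b = 2
    for reply in ["", "invalid", "{}", "[1]", json.dumps({"c": KEY_C})]:
        print(repr(reply), "->", server_round(reply, 0))
```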

Flag: ictf{p13cin9_7h3_b1t5_t0g37her_3f0068c1b9be2547ada52a8020420fb0}.

Although it is solved, I am still curious whether key_m can be recovered from the 1024 queries alone…

DeepSeek pointed me to the print(key_m) in the source code; here is its full response:

To solve this CTF challenge, we need to recover `key_m`, the RSA-encrypted random value whose SHA-256 hash gives the AES key, and then use it to decrypt the flag. The server provides an initial JSON object containing the RSA modulus `n`, the RSA ciphertext of `key_m` (`key_c`), the initialization vector `iv`, and the ciphertext `ct`. After that, the server enters a loop where it sends an index `idx` and expects a JSON response containing a value `c`. If we send an invalid response, the server sets `b` to 2 and simply continues. After 1024 such interactions, the server prints `key_m` to its output. By leveraging this, we can send invalid responses to all queries, receive `key_m` after the loop, and then use it to decrypt the flag.

### Solution Code

```python
import json
import re
import socket
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from hashlib import sha256

def main():
    # Replace 'server_address' and 'port' with the actual server details
    server_address = 'localhost'
    port = 12345

    # Connect to the server. Reading through a file object yields whole lines,
    # whereas a single recv() may return part of a line or several lines at once.
    with socket.create_connection((server_address, port)) as s, s.makefile('rw') as f:
        # Receive the initial data
        initial = json.loads(f.readline())
        n = initial['n']
        key_c = initial['c']
        iv = bytes.fromhex(initial['iv'])
        ct = bytes.fromhex(initial['ct'])

        # Process 1024 queries by sending invalid responses
        for _ in range(1024):
            # Receive the {'idx': idx} message
            f.readline()
            # Send an invalid response (not JSON, so the server sets b = 2)
            f.write('invalid\n')
            f.flush()
            # Receive the {'b': 2} response
            f.readline()

        # After the loop, receive key_m
        data = f.readline().strip()

    try:
        key_m = int(data)
    except ValueError:
        # If there are extra bytes around the number, extract the first run of digits
        match = re.search(r'\d+', data)
        if match:
            key_m = int(match.group())
        else:
            print("Failed to get key_m")
            return

    # Compute the AES key from key_m
    key = sha256(str(key_m).encode()).digest()[:16]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    # The server padded with PKCS#7 (Crypto.Util.Padding.pad), so remove it with unpad
    flag = unpad(cipher.decrypt(ct), 16)
    print(flag.decode())

if __name__ == '__main__':
    main()
```

### Explanation
1. **Initial Connection**: The script connects to the server using socket programming and receives the initial JSON data containing `n`, `key_c`, `iv`, and `ct`.
2. **Handling Queries**: For each of the 1024 queries, the script receives the index `idx` from the server, sends an invalid response (`invalid`), and receives the server's response `{'b': 2}`.
3. **Receiving `key_m`**: After processing all queries, the server prints `key_m`, which the script reads from the socket.
4. **Decryption**: The script computes the AES key by hashing `key_m` with SHA-256 and uses it to decrypt the ciphertext `ct` using AES in CBC mode with the provided `iv`.
5. **Output**: The decrypted flag is printed after removing any padding.

This approach efficiently bypasses the need to interact with the queries meaningfully by exploiting the server's behavior of printing `key_m` after 1024 queries, allowing us to decrypt the flag easily.