Easy Random 6 WP

附件:

try:
    from Crypto.Util.number import getPrime, bytes_to_long
except ImportError:
    # Fall back to the `Cryptodome` package name (as installed by
    # pycryptodomex). Catch ImportError only, so that unrelated errors are
    # not silently swallowed.
    from Cryptodome.Util.number import getPrime, bytes_to_long
import random
import os

# Flag from the GZCTF_FLAG environment variable; the placeholder is used when
# the variable is unset or empty.
flag = os.getenv("GZCTF_FLAG") or "flag{fake_flag_for_testing}"
p = getPrime(256)
x_0 = random.randrange(1, p)
a = random.randrange(2, p)
b = random.randrange(1, p)
x = []
state = x_0
for _ in range(100):
    # LCG step: x_{i+1} = (a * x_i + b) mod p.
    state = (a * state + b) % p
    # Keep only the high 128 bits of the 256-bit state.
    x.append(state // (2 ** 128))

print(a)
print(b)
print(p)
# Leak the first 10 truncated outputs ...
print(x[:10])
# ... and hide the flag by XOR with the 100th truncated output.
print(x[-1] ^ bytes_to_long(flag.encode()))

本题是 Easy Random 4 的增强版:给的不再是完整的 Linear Congruential Generator 的结果,而是只给出它生成的每个 256 位状态的高 128 位(即 `state // 2**128`)。这种问题叫做 Truncated Linear Congruential Generator (Truncated LCG) Recovery,可以参考 Truncated LCG Seed Recovery 进行求解,思路如下:

  1. 已有的数列是 \(x_{n+1} = (ax_n + b) \bmod p\),把它转化成另一个数列:\(y_{n+1} = ay_n \bmod p\),方法是设置 \(y_i = (x_i + b(a-1)^{-1}) \bmod p\)(这要求 \(a \not\equiv 1 \pmod p\);题目中 \(2 \le a < p\),满足条件)
  2. 由于 \(y_{n+1} = ay_n \bmod p\),所以 \(y_n = a^ny_0 \bmod p\),只剩下 \(y_0\) 一个未知数
  3. 因为 \(x_n\) 的高 128 位已知,把它左移 128 位再加上 \(b(a-1)^{-1}\)(模 \(p\)),就得到 \(y_n\) 的估计值 \(\hat{y_n}\);\((y_n - \hat{y_n}) \bmod p\) 恰好是 \(x_n\) 未知的低 128 位,小于 \(2^{128}\)
  4. 为了求解 \(y_0\),把问题转化为一个 Closest Vector Problem (CVP):寻找一个 \(y_0\),使得向量 \((y_0a \bmod p, y_0a^2 \bmod p, \cdots, y_0a^n \bmod p)\) 与已知的向量 \((\hat{y_1}, \hat{y_2}, \cdots, \hat{y_n})\) 最接近(两者每一维只相差不到 \(2^{128}\),远小于 256 位的 \(p\))
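  5. 格上的 CVP 本身并不支持模运算,所以为了实现模运算,构造如下的格,这些行向量的整系数线性组合就可以得到想要的 \((y_0a \bmod p, y_0a^2 \bmod p, \cdots, y_0a^n \bmod p)\):第一行取系数 \(y_1 = y_0a \bmod p\),其余各行取适当的整数系数来完成模 \(p\) 约化(第一维就是 \(y_1\) 本身,已经小于 \(p\),所以第一列不需要 \(p\))
\[ L=\begin{pmatrix} 1 & a & \cdots & a^{n-1} \\ & p & & \\ & & \ddots & \\ & & & p \\ \end{pmatrix} \]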

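先用 LLL 约化这个格的基,再求解 CVP,得到的最近格向量就是 \((y_1, y_2, \cdots, y_n)\)。由 \(y_1\) 减去 \(b(a-1)^{-1}\) 得到 \(x_1\),再由 \(x_0 = (x_1 - b)a^{-1} \bmod p\) 得到初始状态,进而恢复出完整的随机数序列。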

求解代码:

from pwn import *
from Cryptodome.Util.number import long_to_bytes
from fpylll import IntegerMatrix, LLL, CVP


def truncated_lcg(p: int, a: int, b: int, k: int, lsb: bool, x: list[int]) -> int:
    """
    Recover the initial state of a Truncated Linear Congruential Generator,
    using the lattice method from
    [Truncated LCG Seed Recovery](https://github.com/ajuelosemmanuel/Truncated_LCG_Seed_Recovery).

    The generator is x_{i+1} = (a*x_i + b) mod p, and only k bits of each
    output x_1, ..., x_n are observed.

    Args:
        p (int): the prime modulus
        a (int): the multiplier; both a - 1 and a must be invertible mod p
        b (int): the addend
        k (int): the number of bits given per output (at most
            p.bit_length() in MSB mode)
        lsb (bool): True if the k least-significant bits are given,
            False if the k most-significant bits are given
        x (list[int]): k-bit MSB or LSB of x_1, x_2, ..., x_n. In MSB mode
            each entry is x_i >> (p.bit_length() - k); in LSB mode it is
            x_i mod 2**k.

    Returns:
        int: the recovered initial state x_0 (so that x_1 = (a*x_0 + b) mod p)

    Raises:
        ValueError: if x is empty, if k is out of range, or if a - 1 or a
            is not invertible mod p.
    """
    if not x:
        raise ValueError("at least one truncated output is required")
    nbits = p.bit_length()
    if k < 0 or (not lsb and k > nbits):
        raise ValueError(f"invalid number of known bits k={k} for a {nbits}-bit modulus")

    # Each output has `unknown_bits` unknown bits, so the approximation error
    # of every coordinate lies in [0, 2**unknown_bits). Shifting the target
    # by half that range centres the error around 0, halving its worst-case
    # size and making the true lattice vector the closest one more reliably.
    unknown_bits = max(nbits - k, 0)
    center = (1 << unknown_bits) >> 1

    # b * (a-1)^{-1} mod p, computed once. Adding it turns the affine
    # recurrence into a purely multiplicative one:
    #   z_i     = x_i + b(a-1)^{-1}
    #   z_{i+1} = a x_i + b + b(a-1)^{-1} = a x_i + a b(a-1)^{-1} = a z_i  (mod p)
    # pow(..., -1, p) raises ValueError when a == 1 (mod p).
    b_shift = b * pow(a - 1, -1, p) % p

    # steps 1 + 2:
    # approximate the multiplicative sequence w_i (w_i = z_i in MSB mode,
    # w_i = z_i * 2^{-k} in LSB mode) from the known bits.
    if lsb:
        # x_i = high_i * 2^k + low_i with low_i known, hence
        # z_i * 2^{-k} = high_i + (low_i + b(a-1)^{-1}) * 2^{-k}  (mod p),
        # where only the small term high_i < 2**unknown_bits is unknown.
        inv_2k = pow(1 << k, -1, p)
        scale = inv_2k
        offset = b_shift * inv_2k % p
    else:
        # x_i = known_i * 2^{unknown_bits} + low_i with low_i < 2**unknown_bits.
        scale = 1 << unknown_bits
        offset = b_shift
    target = [(el * scale + offset + center) % p for el in x]

    # step 3:
    # lattice basis (rows):
    #   1, a, a^2, ..., a^{n-1}
    #   0, p, 0,   ..., 0
    #   0, 0, p,   ..., 0
    #   ...
    #   0, 0, 0,   ..., p
    # Integer combinations give (w, w*a, ..., w*a^{n-1}) mod p for any w.
    n = len(target)
    basis = [[0] * n for _ in range(n)]
    basis[0] = [pow(a, i, p) for i in range(n)]
    for i in range(1, n):
        basis[i][i] = p

    # step 4:
    # reduce the basis, then find the lattice vector closest to the target;
    # it should be (w_1, w_1*a, w_1*a^2, ...) mod p.
    reduced = LLL.reduction(IntegerMatrix.from_matrix(basis))
    closest = CVP.closest_vector(reduced, target, method="fast")

    # step 5:
    # undo the 2^{-k} scaling (LSB mode) and the b(a-1)^{-1} shift to get x_1.
    w_1 = closest[0] % p
    if lsb:
        w_1 = w_1 * (1 << k) % p
    x_1 = (w_1 - b_shift) % p

    # step 6:
    # step the recurrence back once: x_0 = (x_1 - b) * a^{-1} mod p.
    return (x_1 - b) * pow(a, -1, p) % p

import ast

# Number of LCG steps the challenge runs and number of low bits it drops from
# each state; both must match the challenge script (main.py).
STEPS = 100
DROPPED_BITS = 128

io = process(["python3", "main.py"])
# The challenge prints a, b, p, the first 10 truncated outputs and the
# XOR-encrypted flag, one per line. Parse them as plain literals rather than
# eval()-ing output that comes from another program.
a = int(io.recvline().decode())
b = int(io.recvline().decode())
p = int(io.recvline().decode())
leaked = ast.literal_eval(io.recvline().decode())
encrypted = int(io.recvline().decode())
io.close()

# The challenge reveals the top (p.bit_length() - DROPPED_BITS) bits of each state.
x_0 = truncated_lcg(p, a, b, p.bit_length() - DROPPED_BITS, False, leaked)

# Replay the generator; only the last state is needed to strip the XOR mask.
state = x_0
for _ in range(STEPS):
    state = (a * state + b) % p

flag = encrypted ^ (state >> DROPPED_BITS)
print(long_to_bytes(flag))