ふるつき

v(*'='*)v かに

ACTF 2022 writeup

zer0pts で ACTF に出たのでwriteupです。真面目に取り組んで8位でした。簡単な前2問はよくできていて面白かったです。後ろ2問は特にcrypto的に面白いことがない割に取り組むのが大変でした

Impossible RSA

from Crypto.Util.number import *
from Crypto.PublicKey import RSA

e = 65537
flag = b'ACTF{...}'

while True:
    p = getPrime(1024)
    q = inverse(e, p)
    if not isPrime(q):
        continue
    n = p * q;
    public = RSA.construct((n, e))
    with open("public.pem", "wb") as file:
        file.write(public.exportKey('PEM'))
    with open("flag", "wb") as file:
        file.write(long_to_bytes(pow(bytes_to_long(flag), e, n)))
    break

あまりにきれいな問題設定に、これ既出じゃないのかという驚きが。

RSA q = e^{-1} \mod pです。つまり適当な正整数 k \lt eをもってきて eq = 1 + kpがなりたつので、 en = peq = p(1 + kp) = p + kp^2です。 kを全探索すれば未知数は pだけなので、二次方程式を解けばよいです。

from Crypto.Util.number import getPrime, inverse, isPrime, long_to_bytes, bytes_to_long
from Crypto.PublicKey import RSA
from tqdm import tqdm
from gmpy2 import iroot

key = RSA.import_key(open("public.pem").read())

n = key.n
e = key.e


for k in tqdm(range(1, e)):
    a = k
    b = 1
    c = -n*e

    x, ok = iroot(b**2 - 4*a*c, 2)
    if not ok:
        continue

    if (b + x) % (2*a) == 0:
        p = int(abs((b + x) // (2*a)))
        break

    if (b - x) % (2*a) == 0:
        p = int(abs((b - x) // (2*a)))
        break

q = n // p
d = pow(e, -1, (p-1)*(q-1))

c = bytes_to_long(open("flag", "rb").read())
m = pow(c, d, n)

print(long_to_bytes(m))

RSA Leak

from sage.all import *
from secret import flag
from Crypto.Util.number import bytes_to_long


def leak(a, b):
    p = random_prime(pow(2, 64))
    q = random_prime(pow(2, 64))
    n = p*q
    e = 65537
    print(n)
    print((pow(a, e) + pow(b, e) + 0xdeadbeef) % n)


def gen_key():
    a = randrange(0, pow(2,256))
    b = randrange(0, pow(2,256))
    p = pow(a, 4)
    q = pow(b, 4)
    rp = randrange(0, pow(2,24))
    rq = randrange(0, pow(2,24))
    pp = next_prime(p+rp)
    qq = next_prime(q+rq)
    if pp % pow(2, 4) == (pp-p) % pow(2, 4) and qq % pow(2, 4) == (qq-q) % pow(2, 4):
        n = pp*qq
        rp = pp-p
        rq = qq-q
        return n, rp, rq
    
n, rp, rq = gen_key()
e = 65537
c = pow(bytes_to_long(flag), e, n)
print("n =", n)
print("e =", e)
print("c =", c)
print("=======leak=======")
leak(rp, rq)

'''
n = ...
e = 65537
c = ...
=======leak=======
122146249659110799196678177080657779971
90846368443479079691227824315092288065
'''

 p = a^4 + r_p, q = b^4 + r_q という[RSA] で、[$ n, e, c]の他に n_2, leak = r_p^e + r_q^e + 0xdeadbeef \mod n_2がもらえる。 n_2は十分小さいので素因数分解でき、 r_p, r_qは24bit程度なので全探索できます。 r_pが決まれば自動的に r_qが決まりますが、 n_2が64bitくらいなので、適当にやると r_qも64bitくらいになるところ、正しい r_pを選べていると r_qが24bit程度になるはずなのでわかります。

また、 n \approx a^4b^4 なので n^{1/4} = abがなりたちます。これで n = (a^4 + r_p)(b^4 + r_q), ab = a*bという2式があって、未知数が a, bの2つなので連立方程式を立てれば解けます

from tqdm import tqdm
from sympy.solvers import solve
from sympy import symbols
from gmpy2 import iroot
from Crypto.Util.number import long_to_bytes


n = 3183573836769699313763043722513486503160533089470716348487649113450828830224151824106050562868640291712433283679799855890306945562430572137128269318944453041825476154913676849658599642113896525291798525533722805116041675462675732995881671359593602584751304602244415149859346875340361740775463623467503186824385780851920136368593725535779854726168687179051303851797111239451264183276544616736820298054063232641359775128753071340474714720534858295660426278356630743758247422916519687362426114443660989774519751234591819547129288719863041972824405872212208118093577184659446552017086531002340663509215501866212294702743
e = 65537
c = 48433948078708266558408900822131846839473472350405274958254566291017137879542806238459456400958349315245447486509633749276746053786868315163583443030289607980449076267295483248068122553237802668045588106193692102901936355277693449867608379899254200590252441986645643511838233803828204450622023993363140246583650322952060860867801081687288233255776380790653361695125971596448862744165007007840033270102756536056501059098523990991260352123691349393725158028931174218091973919457078350257978338294099849690514328273829474324145569140386584429042884336459789499705672633475010234403132893629856284982320249119974872840

ab, _ = iroot(n, 4)
a, b = symbols("a, b", integer=True)

n2 = 122146249659110799196678177080657779971
leak = 90846368443479079691227824315092288065

p2 = 8949458376079230661
q2 = 13648451618657980711

d2 = pow(e, -1, (p2-1)*(q2-1))
x = leak - 0xdeadbeef

for rp in tqdm(range(0, 2**24)):
    rq = pow(x - pow(rp, e, n2), d2, n2)
    if rq < 2**24:
        print((rp, rq))

        y = n - (ab**4 + rp*rq)
        solutions = solve([a**4*rq + b**4*rp - y, ab - a*b])
        for sol in solutions:
            print(sol)
            a_ = int(sol[a])

            p = a_**4 + rp
            q = n // p

            d = pow(e, -1, (p-1)*(q-1))
            m = pow(c, d, n)
            print(long_to_bytes(m))

secure connection

力尽きてきた。bluetoothかなにかの通信が実装されていて、パケットのダンプもあるけど一部データがXXに置き換わったりしてます。 CRCがついてるのでこれに合うようにXXの部分を全探索して、shared numeric keyが24bit 程度しかないので全探索!

retros

Revするとこういう感じのVMであることがわかります

"""
VM:
    regs:
        0: pc
        1: memptr
        2: memptr
        3: value
        4: value
        5: global
        6: flag
    mem[32]: shuffled 32 values
    global: number
"""

pc = 0
ptr1 = 1
ptr2 = 2
val1 = 3
val2 = 4
glob = 5
flag = 6


def check_and_halt():
    """
    print flag if mem is sorted
    """
    return [0]


def add_g(idx):
    assert idx in [0, 1, 2, 5, 6]
    return [1, idx]


def sub_g(idx):
    assert idx in [0, 1, 2, 5, 6]
    return [2, idx]


def mv_from_mem(reg, ptr):
    """
    reg = mem[ptr]
    """
    assert reg in [3, 4]  # general register
    assert ptr in [1, 2]  # mem ptr register
    return [3, reg, ptr]


def mv_to_mem(ptr, reg):
    """
    mem[ptr] = reg
    """
    assert ptr in [1, 2]
    assert reg in [3, 4]
    return [4, ptr, reg]


def set_g(val):
    """
    global = val
    """
    assert 0 <= val < 256
    return [5, val >> 4, val & 0x0f]


def set_g_if(val):
    """
    if flag == 1:
        global = val
    """
    assert 0 <= val < 256
    return [6, val >> 4, val & 0x0f]


def memcmp():
    """
    if mem[ptr1] >= mem[ptr2]:
        flag = 1
    """
    return [7]


def cmp_ge(reg):
    """
    if reg >= global
     flag = 1
    """
    assert reg in [0, 1, 2, 3, 4, 5, 6]
    return [8, reg]


def mv_to_g(reg):
    """
    global = reg
    """
    return [9, reg]

31バイト以内でバイトコードを組み立てて送りつけて、32要素のメモリをソートしてcheck_and_haltを呼べばフラグが手に入ります。crypto要素はバイトコードはAES CBCで暗号化したものを送りつけないと行けないけど、鍵は知らないのでpadding oracle encryption attackを使ってオラクルを頼りに暗号文を構築するところで、あとはrev + miscです。

謎の31バイト制限&10000ステップ制限、貧弱なバイトコード、readlineの実装のせいで暗号化したバイトコードに0x0aは含められないなど様々な制約を突破するとフラグが手に入ります。このあたりのバイトコード組み立てとかの詳しいことは id:keymoon が書いてくれるはず。私はRevとpadding oracleしたのと、0x0aが含められなくて困るけどpadding oracle encryptionの最後のブロックは自由度があるのでそこをガチャすればいいって言うかかりをやりました

↓こういう感じで解ける

import random
from ptrlib import Process
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

"""
VM:
    regs:
        0: pc
        1: memptr
        2: memptr
        3: value
        4: value
        5: global
        6: flag
    mem[32]: shuffled 32 values
    global: number
"""

pc = 0
ptr1 = 1
ptr2 = 2
val1 = 3
val2 = 4
glob = 5
flag = 6


def check_and_halt():
    """
    print flag if mem is sorted
    """
    return [0]


def add_g(idx):
    assert idx in [0, 1, 2, 5, 6]
    return [1, idx]


def sub_g(idx):
    assert idx in [0, 1, 2, 5, 6]
    return [2, idx]


def mv_from_mem(reg, ptr):
    """
    reg = mem[ptr]
    """
    assert reg in [3, 4]  # general register
    assert ptr in [1, 2]  # mem ptr register
    return [3, reg, ptr]


def mv_to_mem(ptr, reg):
    """
    mem[ptr] = reg
    """
    assert ptr in [1, 2]
    assert reg in [3, 4]
    return [4, ptr, reg]


def set_g(val):
    """
    global = val
    """
    assert 0 <= val < 256
    return [5, val >> 4, val & 0x0f]


def set_g_if(val):
    """
    if flag == 1:
        global = val
    """
    assert 0 <= val < 256
    return [6, val >> 4, val & 0x0f]


def memcmp():
    """
    if mem[ptr1] >= mem[ptr2]:
        flag = 1
    """
    return [7]


def cmp_ge(reg):
    """
    if reg >= global
     flag = 1
    """
    assert reg in [0, 1, 2, 3, 4, 5, 6]
    return [8, reg]


def mv_to_g(reg):
    """
    global = reg
    """
    return [9, reg]

instructions = []
jump_marker = {}


def jump1(name1):
    # 3: set_g
    # 2: add_g
    return [name1, None, None, None, None]  # placeholder


def jump(name1, name2):
    # 3: set_g
    # 3: set_g_if
    # 2: add_g/sub_g
    return [name2, None, None, name1, None, None, None, None]  # placeholder


def mark(name):
    jump_marker[name] = (len(instructions)) * 4



mark("outer_loop_begin")
# ptr1 += 1
instructions += set_g(1)
instructions += add_g(ptr1)

# ptr2 = 0
instructions += mv_to_g(ptr2)
instructions += sub_g(ptr2)

# flag = ptr1 >= 32
instructions += set_g(32)
instructions += cmp_ge(ptr1)
instructions += jump("check", "inner_loop_begin")

mark("inner_loop_begin")
# flag = mem[ptr1] >= mem[ptr2]
instructions += memcmp()
instructions += jump("inner_loop_end", "swap")

mark("swap")
instructions += mv_from_mem(val1, ptr1)
instructions += mv_from_mem(val2, ptr2)
instructions += mv_to_mem(ptr1, val2)
instructions += mv_to_mem(ptr2, val1)

mark("inner_loop_end")
# flag = ptr2 >= 31
# ptr2 += 1
# if (flag) break;
# else      continue;
instructions += set_g(1)
instructions += add_g(ptr2)

instructions += mv_to_g(ptr1)
instructions += cmp_ge(ptr2)
instructions += jump("outer_loop_begin", "inner_loop_begin")

mark("check")
instructions += check_and_halt()

g_if_ind = set()
# encode instructions
is_jump_forward = False
for i in range(len(instructions)):
    if isinstance(instructions[i], str):
        name = instructions[i]
        jump_to = jump_marker[name]
        if i+3 < jump_to:
            is_jump_forward = True
        else:
            is_jump_forward = False

        if isinstance(instructions[i+3], str):
            g_if_ind.add(i + 3)

        if i in g_if_ind:
            a, b, c = set_g_if(abs(jump_to - (i+5) * 4))
            instructions[i+0] = a
            instructions[i+1] = b
            instructions[i+2] = c
        else:
            a, b, c = set_g(abs(jump_to - (i+8) * 4))
            instructions[i+0] = a
            instructions[i+1] = b
            instructions[i+2] = c

    elif instructions[i] is None:
        if is_jump_forward:
            a, b = add_g(pc)
            instructions[i+0] = a
            instructions[i+1] = b
        else:
            a, b = sub_g(pc)
            instructions[i+0] = a
            instructions[i+1] = b

print(len(instructions), instructions)

i = 0
remain_insns = list(instructions)
res = ""
insts = [(1, "check"), (2, "add_g"), (2, "sub_g"), (3, "mv_from_mem"), (3, "mv_to_mem"), (3, "set_g"), (3, "set_g_if"), (1, "memcmp"), (2, "cmp_ge"), (2, "mv_to_g")]
while len(remain_insns) != 0:
    l, name = insts[remain_insns[0]]
    ops = remain_insns[:l]
    args = ops[1:]
    remain_insns = remain_insns[l:]
    
    op =   (str(ops[0])).rjust(3)
    arg0 = (str(ops[1]) if 2 <= l else "").rjust(3)
    arg1 = (str(ops[2]) if 3 <= l else "").rjust(3)

    if name in ["set_g", "set_g_if"]:
        disasm = f'{name}({ops[1] * 0x10 + ops[2]})'.ljust(18) + f'# {ops[1]}, {ops[2]}'
    else:
        disasm = f'{name}({", ".join(map(str, args))})'
    res += f'{str(i).rjust(3)}: {op} {arg0} {arg1} | {disasm}\n'
    i += l * 4

print(res)

if len(instructions) % 2 == 1:
    instructions.append(0)
payload = []
for i in range(0, len(instructions), 2):
    payload.append((instructions[i] << 4) | instructions[i+1])

if payload[-1] != 0:
    payload.append(0)
print(len(payload), payload)
assert 16 < len(payload) < 32
payload = pad(bytes(payload), 16)

# on remote, encode payload by using padding oracle encryption attack
iv = b"\0"*16
key = b"A" * 16
with open("./key", "wb") as f:
    f.write(key)

"""
aes = AES.new(mode=AES.MODE_CBC, key=key, iv=iv)
ticket = aes.encrypt(payload)

sock = Process("./retros")

input("WAIT> ")
sock.sendline(ticket)
sock.sendlineafter("fortune: ", b"\0"*0x10)
sock.interactive()
"""

import pwn, subprocess

LOCAL = False

# if LOCAL: io = pwn.process('./retros')
if LOCAL: io = pwn.remote("localhost", 8003)
else: io = pwn.remote('123.60.146.157', 9999)

if not LOCAL:
    print(io.recvline())
    dat = io.recvline().split(b'`')[1].decode().split()[2]
    print(dat)
    dat = subprocess.check_output(['hashcash','-mb26',dat])
    print(dat)
    dat = dat.decode().split(' ')[-1].strip()
    print(dat)
    io.sendline(dat)
    print(io.recvline())
    print(io.recvline())

def _do_oracle(ct2, token, j):
    if j < 0:
        return token, False
    
    candidate = []
    last_byte = None
    xval = ([0] * 16 + [16 - j] * (16 - j))[-16:]
    for i in range(256):
        token[j] = i
        send_token = pwn.xor(token, xval)
        if b'\n' in send_token:
            candidate.append(i)
            continue
        io.sendline(send_token + ct2)
        io.sendline(b'\x00' * 16)
    for i in range(256):
        if i in candidate: continue
        r = io.recvline()
        if b'not complete' in r:
            print(j, i, token)
            last_byte = i
    if last_byte is not None:
        token[j] = last_byte
        res, _ = _do_oracle(ct2, bytearray(token), j - 1)
        return res, True

    assert 1 <= len(candidate)
    print(f'[+] {j=} {token[j + 1]=} {len(candidate)=}')
    for i in candidate:
        token[j] = i
        res, confidence = _do_oracle(ct2, bytearray(token), j - 1)
        print(res)
        if confidence:
            print(f'[+] discoverd! {res=}, {confidence=}')
            return res, True
    return res, confidence

def do_oracle(ct2):
    token = bytearray(b'\x00' * 16)
    j = 15
    res, _ = _do_oracle(ct2, token, j)
    return res

init_ct2 = b'superneko'.ljust(16, b'\x00')

ct2 = do_oracle(init_ct2)
print(f'{ct2=}')
ct2 = pwn.xor(ct2, payload[16:])
ct3 = do_oracle(ct2)
print(f'{ct3=}')
ct3 = pwn.xor(ct3, payload[:16])

io.send(ct2 + init_ct2 + b'\n' + ct3 + b'\n')

io.interactive()