ふるつき

v(*'='*)v かに

Plaid CTF 2022 pressure writeup

この土日はPlaid CTF 2022に出ていました。久しぶりにかなり頑張って張り付いた結果一問解くことができたのでwriteupです。

Plaid CTFは歴史ある高難易度CTFとして知られていて問題を解けて嬉しい。

overview

ed25519を用いた一種のDiffie Hellman鍵共有のようなプロトコルが実装されています。pynaclを使った実装になっていて丁寧だけど読みづらい。

from nacl.bindings.crypto_scalarmult import (
  crypto_scalarmult_ed25519_noclamp,
  crypto_scalarmult_ed25519_base_noclamp,
)
from nacl.bindings.crypto_core import (
  crypto_core_ed25519_scalar_mul,
  crypto_core_ed25519_scalar_reduce,
  crypto_core_ed25519_is_valid_point,
  crypto_core_ed25519_NONREDUCEDSCALARBYTES,
  crypto_core_ed25519_BYTES
)
import struct
import os
import ast
import hashlib
import random

def sha512(b):
  return hashlib.sha512(b).digest()

CONST = 4096

SECRET_LEN = int(random.randint(128, 256))
SECRET = [random.randint(1, 255) for i in range(SECRET_LEN)]

with open('flag', 'r') as f:
  FLAG = f.read()

def hsh(s):
  h = sha512(s)
  assert len(h) == crypto_core_ed25519_NONREDUCEDSCALARBYTES
  return crypto_scalarmult_ed25519_base_noclamp(crypto_core_ed25519_scalar_reduce(h))


def generate_secret_set(r):
  s = set()
  for (i, c) in enumerate(SECRET):
    s.add(hsh(bytes(str(i + 25037 * r * c).strip('L').encode('utf-8'))))
  return s


def genr():
  i = 0
  while i == 0:
    i, = struct.unpack('<I', os.urandom(4))
  return i


def handle_client1():
  print("Let's see if we share anything! You be the initiator this time.")
  r = genr()
  s = generate_secret_set(r)
  for k in range(1, CONST):
    s.add(hsh(bytes(str(k + CONST * (r % k)).strip('L').encode('utf-8'))))
  b = crypto_core_ed25519_scalar_reduce(os.urandom(crypto_core_ed25519_NONREDUCEDSCALARBYTES))
  server_s = set(crypto_scalarmult_ed25519_noclamp(b, e) for e in s)

  client_s = set()
  print("Send your data!")
  got = ast.literal_eval(input())
  for e in got:
    if not crypto_core_ed25519_is_valid_point(e):
      print("Bad client!")
      exit()

    client_s.add(e)

  server_combined_client = set(
    crypto_scalarmult_ed25519_noclamp(b, e) for e in client_s
  )

  client_resp1 = [e for e in server_combined_client]
  client_resp2 = [e for e in server_s]
  random.shuffle(client_resp1)
  random.shuffle(client_resp2)
  print(repr(client_resp1))
  print(repr(client_resp2))

  return r, s

def handle_client2(r, s):
  print("Let's see if we share anything! I'll be the initiator this time.")
  b = crypto_core_ed25519_scalar_reduce(os.urandom(crypto_core_ed25519_NONREDUCEDSCALARBYTES))
  server_s = set(crypto_scalarmult_ed25519_noclamp(b, e) for e in s)
  to_client = [e for e in server_s]
  random.shuffle(to_client)
  print(repr(to_client))

  client_s = set()
  print("Send client points: ")
  got = ast.literal_eval(input())
  for e in got:
    if not crypto_core_ed25519_is_valid_point(e):
      print("Bad client!")
      exit()

    client_s.add(e)

  masked_s = set()
  print("Send masked server points: ")
  got = ast.literal_eval(input())
  for e in got:
    if not crypto_core_ed25519_is_valid_point(e):
      print("Bad client!")
      exit()

    masked_s.add(e)

  if len(masked_s) != len(server_s):
    print("Bad client!")
    exit()

  if masked_s & server_s:
    print("Bad client!")
    exit()

  masked_c = set(crypto_scalarmult_ed25519_noclamp(b, e) for e in client_s)
  if masked_c == masked_s:
    print(FLAG)
  else:
    print("Aw, we don't share anything.")

def main():
  r, s = handle_client1()
  handle_client2(r, s)

if __name__ == "__main__":
  main()

 rは32bit(!)の乱数で、 s

 s = \lbrace  sha512( i + 25037rc_i )*G \rbrace + \lbrace sha512(i + 4096*(r \mod i))*G \rbrace

という集合です。最初にed25519上の点の集合 Aを送り、その後256bit程度の乱数 b, b'を用いた \lbrace b*P | P \in A \rbrace \lbrace b*P | P \in s \rbrace \lbrace b'*P | P \in s \rbrace がもらえます。 \lbrace b'*P | P \in B \rbrace = C となるような点の集合 B, Cを求めることができればフラグを取得できます。ただし C \ne sかつ |C| = |s|である必要があります。

方針

最終的な目的を達成するためには sの具体的な値を知ることが不可欠なので、まずは sを求めたいというのが自然な発送に見えます。そして rが4バイトと小さいことから、まずは rを求め、その後 s rを用いて算出することになりそうです。

rを求める

 rを求めるためには \lbrace b*P | P \in A \rbrace \lbrace b*P | P \in s \rbraceをうまく使う必要がありそうです。

 sのうち  \lbrace sha512(i + 4096*(r \mod i))*Gの部分は rが用いられていて未知の値は r \mod iだけで、未知と言っても候補の列挙が十分できる程度なのでこれが使えそうです。具体的には r \mod iから複数の iについて rの剰余を求めて中国剰余定理で rを復元できます。

たとえばこちら側で X = sha512(2 + 4096*0)*G, Y = sha512(2 + 4096*1)*Gを用意しておけば、 X sに含まれていれば r \equiv 0 \mod 2であり、 Y sに含まれていれば r \equiv 1 \mod 2であるということがわかるといった具合です。

ただし、たとえば A = \lbrace X, Y \rbraceとしても返ってきた \lbrace b*P | P \in A \rbraceでは未知数 bが掛けられていてどちらが b*Xでどちらが b*Yに相当するのかはわからなくなっています。2通り程度なら両方のパターンを試すこともできますが、全ての剰余についてそれをやっていては32bitの全探索をやっているのと同じです。

そこで  \lbrace b*P | P \in A \rbraceと送った点の集合との対応を探す必要があります。ここで、仮に i = 1の場合の点 P_1 = sha512(1 + 4096*0)*Gに対応する bP_1がわかっているとすると、 X = (sha512(1 + 4096*0))^{-1} * sha512(2 + 4096*0) * P_1 であることから  (sha512(1 + 4096*0))^{-1} * sha512(2 + 4096*0) * bP_1に一致する点があればそれが bXであると判断することができます。

  \lbrace b*P | P \in A \rbraceの全ての点についてその点が bP_1であると仮定して上記のように bX bYを求め、もしどちらかが \lbrace b*P | P \in s \rbrace内に存在すれば、仮定が正しいと判断することができ、 bP_1を見つけ出すことができます。 bP_1がわかれば、同じ要領で対応付けを行うことで  \lbrace b*P | P \in A \rbraceの点と Aの点の対応を求めることができます。

あとは  \lbrace b*P | P \in A \rbraceの点のうち \lbrace b*P | P \in s \rbraceに含まれるものから r \mod iがわかるので、 rを計算することができます。 rの計算には比較的大きい剰余が3つもあれば十分です。

points = [
    hsh(str(1 + CONST*0).encode()),
    hsh(str(2 + CONST*0).encode()),
    hsh(str(2 + CONST*1).encode()),
]

for k in [4091, 4093, 4905]:
    for rem in range(k):
        points.append(hsh(str(k + CONST*rem).encode()))

sock.sendlineafter("data!\n", repr(points))  # Aを送っている

xs = ast.literal_eval(sock.recvline().decode())  # {b*P | P in A}
ys = set(ast.literal_eval(sock.recvline().decode())) # {b*P | P in s}

inv = crypto_core_ed25519_scalar_invert(h_int(str(1 + 4096*0).encode()))
k_0 = crypto_core_ed25519_scalar_mul(inv, h_int(str(2 + 4096*0).encode()))
k_1 = crypto_core_ed25519_scalar_mul(inv, h_int(str(2 + 4096*1).encode()))

pairs = []
P1_idx = None
for i, P1 in enumerate(xs):
    bX = crypto_scalarmult_ed25519_noclamp(k_0, P1)
    bY = crypto_scalarmult_ed25519_noclamp(k_1, P1)

    if not (bX in ys or bY in ys):
        continue

    # bXあるいはbYが見つかったということは仮定が正しく、P1を見つけることができた
    P1_idx = i

    # get r mod k
    for k in [4091, 4093, 4095]:
        for rem in range(k):
            krem = crypto_core_ed25519_scalar_mul(
                inv,
                h_int(str(k + CONST*rem).encode()),
            )
            bkremG = crypto_scalarmult_ed25519_noclamp(krem, P1)
            if bkremG in ys:
                pairs.append((rem, k))
                break
assert len(pairs) == 3
r, _ = crt(pairs)
print("[+] found r: {}".format(r))

sを求める

 s = \lbrace  sha512( i + 25037rc_i )*G \rbrace + \lbrace sha512(i + 4096*(r \mod i))*G \rbrace

で、 c_iもさほど大きくないので、同じ要領で候補を全て bP_1から計算してそれが \lbrace b*P | P \in s \rbraceに含まれるかどうかを見れば候補が正しいかどうかを判定できます。

# calculate s
s = []
for i in range(256):
    for c in range(1, 256):
        k = crypto_core_ed25519_scalar_mul(
            inv,
            h_int(str(i + 25037*r*c).encode()),
        )
        bkG = crypto_scalarmult_ed25519_noclamp(k, P1)
        if bkG in ys:
            s.append(hsh(str(i + 25037*r*c).encode()))

for k in range(1, CONST):
    key = crypto_core_ed25519_scalar_mul(
        inv,
        h_int(str(k + CONST * (r % k)).encode()),
    )
    kG = crypto_scalarmult_ed25519_noclamp(key, P1)
    if kG in ys:
        s.append(hsh(str(k + CONST * (r % k)).encode()))

assert len(s) == len(ys)
print("[+] found s")

送るべきB, Cを求める

 \lbrace b'*P | P \in B \rbrace = Cかつ C \ne s,  |C| = |s|であるような B, Cを求めます。 B = s, C = \lbrace b'*P | P \in s \rbraceとできれば一番楽ですが、これは禁じられているので、かわりに B = \lbrace 2*P | P \in s \rbrace, C = \lbrace 2*Q | Q \in  \lbrace b'*P | P \in s \rbrace \rbraceとすればいいです。

exploit

完成したexploitがこちらになります。libsodiumの演算で点を定数倍するときのint -> bytesの変換に自身がなかったのでcrypto_core_ed25519_addを使っているのがおちゃめポイントです

from nacl.bindings.crypto_scalarmult import (
  crypto_scalarmult_ed25519_noclamp,
  crypto_scalarmult_ed25519_base_noclamp,
)
from nacl.bindings.crypto_core import (
  crypto_core_ed25519_add,
  crypto_core_ed25519_scalar_mul,
  crypto_core_ed25519_scalar_reduce,
  crypto_core_ed25519_scalar_invert,
  crypto_core_ed25519_is_valid_point,
  crypto_core_ed25519_NONREDUCEDSCALARBYTES,
  crypto_core_ed25519_BYTES
)
import hashlib
import struct
import os
from ptrlib import Socket, Process, crt
import ast

CONST = 4096

def sha512(b):
  return hashlib.sha512(b).digest()

def hsh(s):
  h = sha512(s)
  assert len(h) == crypto_core_ed25519_NONREDUCEDSCALARBYTES
  return crypto_scalarmult_ed25519_base_noclamp(crypto_core_ed25519_scalar_reduce(h))

def h_int(s):
  h = sha512(s)
  assert len(h) == crypto_core_ed25519_NONREDUCEDSCALARBYTES
  return crypto_core_ed25519_scalar_reduce(h)


# -- main
sock = Socket("nc pressure.chal.pwni.ng 1337")
p = Process(["bash", "-c", sock.recvline().decode()])
sock.sendline(p.recvlineafter("hashcash token: "))
p.close()
# sock = Socket("nc localhost 9999")

points = [
    hsh(str(1 + CONST*0).encode()),
    hsh(str(2 + CONST*0).encode()),
    hsh(str(2 + CONST*1).encode()),
]
for k in [4091, 4093, 4905]:
    for rem in range(k):
        points.append(hsh(str(k + CONST*rem).encode()))

sock.sendlineafter("data!\n", repr(points))

xs = ast.literal_eval(sock.recvline().decode())
ys = set(ast.literal_eval(sock.recvline().decode()))

inv = crypto_core_ed25519_scalar_invert(h_int(str(1 + 4096*0).encode()))
k_0 = crypto_core_ed25519_scalar_mul(inv, h_int(str(2 + 4096*0).encode()))
k_1 = crypto_core_ed25519_scalar_mul(inv, h_int(str(2 + 4096*1).encode()))

pairs = []
hG_idx = None
for i, hG in enumerate(xs):
    # assume hG as b*h(1)*G
    bk0G = crypto_scalarmult_ed25519_noclamp(k_0, hG)
    bk1G = crypto_scalarmult_ed25519_noclamp(k_1, hG)

    if not (bk0G in ys or bk1G in ys):
        continue
    # base is found
    hG_idx = i

    # get r mod k
    for k in [4091, 4093, 4095]:
        for rem in range(k):
            krem = crypto_core_ed25519_scalar_mul(
                inv,
                h_int(str(k + CONST*rem).encode()),
            )
            bkremG = crypto_scalarmult_ed25519_noclamp(krem, hG)
            if bkremG in ys:
                pairs.append((rem, k))
                break
assert len(pairs) == 3
r, _ = crt(pairs)
print("[+] found r: {}".format(r))

hG = xs[hG_idx]

# calculate s
s = []
for k in range(1, CONST):
    key = crypto_core_ed25519_scalar_mul(
        inv,
        h_int(str(k + CONST * (r % k)).encode()),
    )
    kG = crypto_scalarmult_ed25519_noclamp(key, hG)
    if kG in ys:
        s.append(hsh(str(k + CONST * (r % k)).encode()))


for i in range(256):
    for c in range(1, 256):
        k = crypto_core_ed25519_scalar_mul(
            inv,
            h_int(str(i + 25037*r*c).encode()),
        )
        bkG = crypto_scalarmult_ed25519_noclamp(k, hG)
        if bkG in ys:
            s.append(hsh(str(i + 25037*r*c).encode()))
assert len(s) == len(ys)
print("[+] found s")


# handle_client2
sock.recvlineafter("this time.")
b_server_s = ast.literal_eval(sock.recvline().decode())

sock.sendlineafter("client points: \n", repr([crypto_core_ed25519_add(e, e) for e in s]))  # 2*s
# masked_c = b*2*s
sock.sendlineafter("masked server points: \n", repr([crypto_core_ed25519_add(e, e) for e in b_server_s]))  # b*s + b*s = 2b*s

sock.interactive()

感想

主にpynaclとlibsodiumのAPIについての情報が少なかったこと、変数の命名に失敗して*1で時間を取られて結局4時間〜5時間くらい掛けました。もっと整ってたら半分の時間で解けたはず……。序盤はlibsodiumがわかりづらすぎてブチ切れてたけど終わってみたら面白かったです

*1:渡されたスクリプト命名が悪いんですけど