# 3kCTF 2021 - Digital writeup

まずスクリプトを眺めます。

from Crypto.Util.number import inverse
import hashlib
import os

rol = lambda val, r_bits, max_bits: (val << r_bits%max_bits) & (2**max_bits-1) | ((val & (2**max_bits-1)) >> (max_bits-(r_bits%max_bits)))

class Random():
def __init__(self, seed):
self.state = seed
self.bits = self.state.bit_length()

def next(self):
self.state ^= self.state << 76
self.state = rol(self.state, 32, self.bits)
self.state ^= self.state >> 104
self.state = rol(self.state, 20, self.bits)
self.state ^= self.state << 116
self.state = rol(self.state, 12, self.bits)
return self.state

def sign(message):
h = int(hashlib.sha256(message).hexdigest(), 16)
k = random.next()
r = pow(g, k, p) % q
s = inverse(k, q) * (h + x*r) % q
return (r, s)

def verify(message, r, s):
h = int(hashlib.sha256(message).hexdigest(), 16)
w = inverse(s, q)
u1 = h * w % q
u2 = r * w % q
v = (pow(g, u1, p) * pow(y, u2, p) % p ) % q
return v == r

random = Random(int(os.urandom(16).hex(), 16))
q = 0xc313d1a2bf3516a555c54875798a59a3d219ea76179b712886beec177263cec7

y = pow(g, x, p)

MSG2 = b'Susan decorated them.'
r1, s1 = sign(MSG1)
r2, s2 = sign(MSG2)
assert verify(MSG1, r1, s1)
assert verify(MSG2, r2, s2)
print ("y = ", y)
print ("r1 = ", r1)
print ("s1 = ", s1)
print ("r2 = ", r2)
print ("s2 = ", s2)

"""
y =  5624204323708883762857532177093000216929823277043458966645372679201025592769376026088466517180933057673841523705217308006821461505613041092599344214921758292705684588442147606413017270932589190682167865180010809895170865252326994825400330559172774619220024016595462686075240147992717554220738390033531322461011161893179173499597221230442911598574630392043521768535083211677909300720125573266145560294501586465872618003220096582182816143583907903491981432622413089428363003509954017358820731242558636829588468685964348899875705345969463735608144901602683917246879183938340727739626879210712728113625391485513623273477
r1 =  53670875511938152371853380079923244962420018116685861532166510031799178241334
s1 =  6408272343562387170976380346088007488778435579509591484022774936598892550745
r2 =  3869108664885100909066777013479452895407563047995298582999261416732594613401
s2 =  63203374922611188872786277873252648960215993219301469335034797776590362136211
"""


いくつか注目すべき点があります。与えられているのはという離散対数問題のインスタンスがひとつと、同じを使い、また離散対数問題のときの秘密鍵として用いたDSAのインスタンスがふたつです。また、DSAの署名時の乱数は、なにやら怪しげな乱数生成器によって与えられています。

ということでDSAのインスタンスに注目します。ここで、およびその他の変数はおよそ256bitの値であるのに対して、はseedが128bitであることから最大でも128bitの値であるとわかります。この値だけ比較的小さいので、LLLやそれを用いたCoppersmith法で求められる可能性があります。今回は2インスタンスしかありませんから、適当に制約をつないでmultivariate Coppersmithを試してみます。

まず式を立てます。を変形して です。これが2インスタンスありますから、それぞれ などと適当に番号を振って、式同士を減算すると です。これを満たすような適当な根は見つけられるでしょうか。

これは見つけられます。見つけられますが、この根が正しい（問題のインスタンスで使われたものと一致する）根ではない可能性があります。これは運ですが、今回のインスタンスについてはだめでした。そこで、さらなる条件を付け加えてやる必要があります。これはチームメイトの S3v3ru5 が教えてくれたのですが、の上位64bitとの下位64bitは共通です。これはたとえば下記のようなスクリプトを書いて確認することができます

from z3 import *

rol = lambda val, r_bits, max_bits: (val << r_bits%max_bits) & (2**max_bits-1) | ((val & (2**max_bits-1)) >> (max_bits-(r_bits%max_bits)))

class Random():
def __init__(self, seed):
self.state = seed
# self.bits = self.state.bit_length()

def next(self):
self.state ^= self.state << 76
self.state = RotateLeft(self.state, 32)
self.state ^= LShR(self.state, 104)
self.state = RotateLeft(self.state, 20)
self.state ^= self.state << 116
self.state = RotateLeft(self.state, 12)
return self.state

xs = [BitVec("x{}".format(i), 1) for i in range(16 * 8)]
x = Concat(*xs)
r = Random(x)
print(simplify(r.next()))


この性質を使うと128bit 2変数の多項式を64bit 3変数の多項式にできます。こちらの多項式を使ったMultivariate Coppersmithで正しいを得られ、を復元してフラグを得ることができました。

from Crypto.Util.number import *
from hashlib import sha256

# https://raw.githubusercontent.com/defund/coppersmith/master/coppersmith.sage
import itertools

def small_roots(f, bounds, m=1, d=None):
if not d:
d = f.degree()

R = f.base_ring()
N = R.cardinality()

f /= f.coefficients().pop(0)
f = f.change_ring(ZZ)

G = Sequence([], f.parent())
for i in range(m+1):
base = N^(m-i) * f^i
for shifts in itertools.product(range(d), repeat=f.nvariables()):
g = base * prod(map(power, f.variables(), shifts))
G.append(g)

B, monomials = G.coefficient_matrix()
monomials = vector(monomials)

factors = [monomial(*bounds) for monomial in monomials]
for i, factor in enumerate(factors):
B.rescale_col(i, factor)

B = B.dense_matrix().LLL()

B = B.change_ring(QQ)
for i, factor in enumerate(factors):
B.rescale_col(i, 1/factor)

H = Sequence([], f.parent().change_ring(QQ))
for h in filter(None, B*monomials):
H.append(h)
I = H.ideal()
if I.dimension() == -1:
H.pop()
elif I.dimension() == 0:
roots = []
for root in I.variety(ring=ZZ):
root = tuple(R(root[var]) for var in f.variables())
roots.append(root)
return roots

return []

q = 0xc313d1a2bf3516a555c54875798a59a3d219ea76179b712886beec177263cec7

y =  5624204323708883762857532177093000216929823277043458966645372679201025592769376026088466517180933057673841523705217308006821461505613041092599344214921758292705684588442147606413017270932589190682167865180010809895170865252326994825400330559172774619220024016595462686075240147992717554220738390033531322461011161893179173499597221230442911598574630392043521768535083211677909300720125573266145560294501586465872618003220096582182816143583907903491981432622413089428363003509954017358820731242558636829588468685964348899875705345969463735608144901602683917246879183938340727739626879210712728113625391485513623273477
r1 =  53670875511938152371853380079923244962420018116685861532166510031799178241334
s1 =  6408272343562387170976380346088007488778435579509591484022774936598892550745
r2 =  3869108664885100909066777013479452895407563047995298582999261416732594613401
s2 =  63203374922611188872786277873252648960215993219301469335034797776590362136211

MSG2 = b'Susan decorated them.'
h1 = int(sha256(MSG1).hexdigest(), 16)
h2 = int(sha256(MSG2).hexdigest(), 16)
r1inv = int(inverse_mod(r1, q))
r2inv = int(inverse_mod(r2, q))

s1inv = int(inverse_mod(s1, q))
s2inv = int(inverse_mod(s2, q))

PR.<k1u, k2l, k> = PolynomialRing(GF(q))

k1 = k1u*2^64 + k
k2 = k*2^64 + k2l

f = r1inv*(k1*s1 - h1) - r2inv*(k2*s2 - h2)
roots = small_roots(f, [2^64, 2^64, 2^64], m=2, d=2)
print(roots)
for root in roots:
k1u, k2l, k = [int(r) for r in root]
k1 = k1u*2^64 + k
k2 = k*2^64 + k2l

x1 = (r1inv*(k1*s1 - h1)) % q
x2 = (r2inv*(k2*s2 - h2)) % q

print(x1)
print(int(x1).to_bytes(100, "big"))
print(int(x2).to_bytes(100, "big"))
print(pow(g, int(x1), p) == y)
print(pow(g, int(x2), p) == y)