n³ writeup

Hi! Sorry for making an NN challenge. I tried to make it as easy as possible so that most CTF players could solve it.
Hope you enjoy this challenge!

In [ ]:
import onnx
from onnx import numpy_helper
import matplotlib.pyplot as plt
import numpy as np

model = onnx.load("net_obf.onnx")

def find_node_on_graph(name, graph=model.graph):
    """Return the node called `name`, recursing into the branches of If nodes."""
    for node in graph.node:
        if node.name == name:
            return node
        if node.op_type == "If":
            # then_branch / else_branch are stored as graph-valued attributes
            for attr in node.attribute:
                res = find_node_on_graph(name, attr.g)
                if res is not None:
                    return res
    return None

def find_initializer(name, graph=model.graph):
    """Return initializer `name` as a numpy array, recursing into If branches."""
    for init in graph.initializer:
        if init.name == name:
            return numpy_helper.to_array(init)
    for node in graph.node:
        if node.op_type == "If":
            for attr in node.attribute:
                res = find_initializer(name, attr.g)
                if res is not None:
                    return res
    return None

Part 1

If_0_then_branch

AveragePool -> Conv (effectively just a weighted sum)

So each conv just sums input values. The first conv (510ca3cad4acbffd) has weights that are all 1, so it sums the whole input. The second conv (e9ebd4158d366acc) has weights of only 0 and 1, so it sums only the pixels where its weight is 1.
The check is quite specific. The maximum possible value is 36864, yet the sum has to be approximately 32400 and the two sums must differ by less than 300. Almost every bright input pixel therefore has to sit where the second kernel is 1, so the weight of the second conv (e9ebd4158d366acc) is very likely the image itself.
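The following numpy sketch shows why those numbers pin the image down. It models the check as described above and does not reproduce the graph's exact logic. Several names and values are hypothetical: the function name `part1_check`, the 3x48x256 input layout (chosen only because 3*48*256 = 36864), and the ±300 tolerance around the 32400 target. The kernel `w` is a random stand-in for the real 0/1 weights.

```python
import numpy as np

# Hypothetical model of the Part 1 check (tolerances are assumptions):
#   total  = conv with an all-ones kernel -> sum of every input value
#   masked = conv with a 0/1 kernel `w`   -> sum of the inputs where w == 1
def part1_check(x, w, target=32400, tol=300):
    total = float(x.sum())
    masked = float((x * w).sum())
    return abs(total - target) < tol and (total - masked) < tol

rng = np.random.default_rng(0)
n = 3 * 48 * 256                  # 36864, the maximum possible sum (assumed layout)
flat = np.zeros(n)
flat[:32400] = 1.0                # exactly 32400 bright pixels
w = rng.permutation(flat).reshape(3, 48, 256)      # stand-in for the 0/1 kernel
other = rng.permutation(flat).reshape(3, 48, 256)  # same brightness, other positions

print(part1_check(w, w))      # True: the kernel itself passes
print(part1_check(other, w))  # False: thousands of bright pixels miss the mask
```

An image with the same number of bright pixels in different places misses the mask on thousands of pixels. Only something very close to `w` itself can keep the difference under the tolerance.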

In [ ]:
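# First conv: every weight must be 1, so it sums the whole (pooled) input
conv_node1=find_node_on_graph('510ca3cad4acbffd')
assert (find_initializer(conv_node1.input[1])==1).all()
# Second conv: 0/1 weights, which should be the image itself
conv_node2=find_node_on_graph('e9ebd4158d366acc')
# (1, C, H, W) kernel -> drop the output-channel axis, then CHW -> HWC for imshow
pt1=find_initializer(conv_node2.input[1]).squeeze(0).transpose((1,2,0))
plt.imshow(pt1)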

Part 2

If_0_then_branch__If_0_then_branch

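Notation: I denotes the result of the previous node.

normalize -> mul(I*255) -> add(I+A) -> mul(I*137) -> mod(I%257) -> (concat(I|B) -> conv). The last concat + conv pair actually just computes I-B.

Assume I-B has to be zero at the end. Then the pixel value v = 255*x satisfies ((v + A) * 137) mod 257 = B. Because 257 is prime, 137 has an inverse mod 257, so v = (B * 137^-1 - A) mod 257. In the cell below, A is the initializer 8e10760202cf3d3f and B is 21add3e5c814986b.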
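This plain-Python sanity check confirms that the chain above can be undone with the modular inverse of 137 mod 257, as the solve cell does. It works on one integer pixel value `v = 255*x` at a time. `forward` and `invert` are illustrative helpers and are not part of the challenge. The only real constants are 137 and 257 from the graph.

```python
P, M = 257, 137
M_INV = pow(M, -1, P)  # modular inverse of 137 mod 257 (needs Python 3.8+)

def forward(v, a):
    """The graph's chain for one pixel, with v = 255*x already applied."""
    return ((v + a) * M) % P

def invert(b, a):
    """Undo the chain: recover v from the stored target b and offset a."""
    return (b * M_INV - a) % P

print(M_INV)  # 242

# Every pixel value 0..255 round-trips for every possible offset a.
ok = all(invert(forward(v, a), a) == v for v in range(256) for a in range(P))
print(ok)  # True
```

The round trip works because v < 257, so reducing mod 257 never loses information for a valid pixel.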

In [ ]:
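# Invert ((255*x + A) * 137) % 257 == B:  255*x = (B * 137^-1 - A) mod 257
# B = '21add3e5c814986b', A = '8e10760202cf3d3f'; pow(137, -1, 257) needs Python 3.8+
F=((find_initializer('21add3e5c814986b')*pow(137,-1,257))-find_initializer('8e10760202cf3d3f'))%257
# F is 255*x, so divide by 255 to recover x, then CHW -> HWC
pt2=(F/255).squeeze(0).transpose((1,2,0))
plt.imshow(pt2)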

Part 3

This part is in If_0_then_branch__If_0_then_branch__If_0_then_branch.

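A voting ensemble of three MLP-like networks. Because it is a majority vote, it is enough to make two of them output the target class, which is why only two are rebuilt below.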
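Here is a sketch of why beating two of the three voters is enough. It assumes the ensemble takes each network's argmax and returns the most common label. The helper `majority_vote` and the logits are made up for illustration, and the ONNX graph may implement the vote with different ops.

```python
import numpy as np

def majority_vote(logits_list):
    """Hypothetical vote: argmax per model, then the most frequent label
    (np.argmax over the counts picks the smallest label on a tie)."""
    votes = [int(np.argmax(logits)) for logits in logits_list]
    return int(np.argmax(np.bincount(votes)))

fooled_1 = np.array([0.2, 3.1])   # predicts class 1
fooled_2 = np.array([-1.0, 0.5])  # predicts class 1
honest = np.array([4.0, -2.0])    # the untouched third model predicts class 0

print(majority_vote([fooled_1, fooled_2, honest]))  # 1
```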

In [ ]:
import torch

def Linear(name):
    """Rebuild a Gemm node (inputs: X, W, bias) as a torch.nn.Linear.
    Assumes transB=1, i.e. W is stored as (out_features, in_features)."""
    node=find_node_on_graph(name)
    W=find_initializer(node.input[1])
    B=find_initializer(node.input[2])
    l=torch.nn.Linear(W.shape[1],W.shape[0])
    l.weight=torch.nn.Parameter(torch.tensor(W))
    l.bias=torch.nn.Parameter(torch.tensor(B))
    return l

# Two of the three voters; requires_grad_(False) freezes them so only the input is optimized
tor1=torch.nn.Sequential(
    torch.nn.Flatten(),Linear("bd81aa784a4caf2d"),
    torch.nn.SiLU(),Linear("9b038baac1e62a86"),
    torch.nn.SiLU(),Linear("f4581a9ca6f0a910"),
).eval().requires_grad_(False)
tor2=torch.nn.Sequential(
    torch.nn.Flatten(),Linear("931fe8d6e3b14f28"),
    torch.nn.GELU(),Linear("89f47d3cdfe769db"),
    torch.nn.GELU(),Linear("57e0c901e157786f"),
).eval().requires_grad_(False)
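The `Linear` helper relies on each layer being an ONNX Gemm node with transB=1 and alpha = beta = 1. Those attributes aren't shown here, so treat this as an assumption worth checking on the actual nodes. Under it, Gemm's formula Y = alpha * A @ B^T + beta * C is exactly what `torch.nn.Linear` computes. The numpy sketch below checks this on random data. `gemm` is a hypothetical reference version of the spec's formula, not an onnx API.

```python
import numpy as np

def gemm(a, b, c, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    """Reference version of the ONNX Gemm formula: alpha * A' @ B' + beta * C."""
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    return alpha * (a @ b) + beta * c

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 5))  # one flattened input row
W = rng.normal(size=(3, 5))  # stored as (out_features, in_features), like Linear.weight
bias = rng.normal(size=3)

linear_out = x @ W.T + bias             # what torch.nn.Linear(5, 3) computes
gemm_out = gemm(x, W, bias, trans_b=1)  # Gemm with transB=1
print(np.allclose(linear_out, gemm_out))  # True
```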
In [ ]:
# Optimize one grayscale 48x192 image (tiled to 3 channels) so both networks predict class 1
adv=torch.ones((1,1,48,192),requires_grad=True)
optim=torch.optim.Adam([adv],lr=0.1)
criterion=torch.nn.CrossEntropyLoss()
target=torch.tensor([1])
loss=1
epoch=1
while loss>0.03 and epoch<100:
    epoch+=1
    optim.zero_grad()
    advi=adv.clamp(0,1).tile((1,3,1,1))  # keep pixels in [0,1], gray -> RGB
    loss=criterion(tor1(advi),target)+criterion(tor2(advi),target)
    # Regularizer: pull the top 8 and bottom 8 rows toward 1 (white), presumably blank margins
    loss+=((1-adv[:,:,:8,:])**2).sum()+((1-adv[:,:,40:,:])**2).sum()
    loss.backward()
    print(str(loss.item())+" "*10,end="\r")
    optim.step()
0.04737356677651405           
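The loop above is plain input optimization. The networks stay fixed while gradient descent edits the pixels, the input is clamped to [0,1], and an extra penalty pulls the margin rows toward white. Below is a dependency-free numpy sketch of the same idea on a toy logistic "model" with a hand-picked checkerboard weight. Everything about the toy is an assumption made to keep it small: the model itself, the 8x12 size, the 2-row margins, and the plain projected gradient step used instead of Adam.

```python
import numpy as np

H, W = 8, 12
rows, cols = np.indices((H, W))
w = np.where((rows + cols) % 2 == 0, 1.0, -1.0)  # toy linear model: logit = sum(w * x)
margin = np.zeros((H, W), dtype=bool)
margin[:2] = True                                # top rows that should stay white
margin[-2:] = True                               # bottom rows that should stay white

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.ones((H, W))  # start from an all-white image, like torch.ones above
lr = 0.1
for _ in range(200):
    z = float((w * x).sum())
    # d/dx of the class-1 loss softplus(-z) is -sigmoid(-z) * w
    grad = -sigmoid(-z) * w
    # d/dx of the margin penalty sum((1 - x)^2) is -2 * (1 - x)
    grad = grad + np.where(margin, -2.0 * (1.0 - x), 0.0)
    x = np.clip(x - lr * grad, 0.0, 1.0)  # projected step keeps pixels in [0, 1]

logit = float((w * x).sum())
print(logit > 0, x[margin].min() >= 0.5)  # True True
```

As in the real cell, the classifier term and the margin penalty pull in different directions on the margin rows. The margin pixels settle close to white once the logit is comfortably positive.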
In [ ]:
# Binarize the optimized image at 0.5, then CHW -> HWC
pt3=(advi.detach().numpy()>0.5).astype(np.float32).squeeze(0).transpose((1,2,0))
plt.imshow(pt3)

Part 4

If_0_then_branch__If_0_then_branch__If_0_then_branch__If_0_then_branch

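mul(I*2) -> sub(I-1) (maps {0,1} to {-1,1}) -> mul(I[:256] * I[256:512] * I[512:768] * X)

With values in {-1,1}, multiplication acts as XNOR: equal signs give 1 and different signs give -1. This whole operation is therefore itself a chained XNOR. Each ±1 value is also its own inverse. So if the product has to come out as all 1s, the third slice is simply I[:256] * I[256:512] * X. The cell below keeps only columns 128 onward of that result, presumably because the first 128 columns of the third slice are already covered by Parts 1-3.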
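This numpy check shows two properties: the ±1 product behaves like XNOR, and multiplying by a ±1 value twice undoes it. Together they let the missing slice be solved for. The helper `to_pm1`, the random slices, and the all-ones target are illustrative assumptions, not values from the model.

```python
import numpy as np

def to_pm1(bits):
    """Map {0, 1} to {-1, +1}, like the mul(I*2) -> sub(I-1) pair."""
    return np.asarray(bits) * 2 - 1

a = np.array([0, 0, 1, 1])
b = np.array([0, 1, 0, 1])
xnor_bits = 1 - (a ^ b)       # 1 where the bits agree
prod = to_pm1(a) * to_pm1(b)  # equals to_pm1(xnor_bits)

# Recovering an unknown slice, assuming p * q * r * x must be +1 everywhere.
rng = np.random.default_rng(1)
p, q, r = (to_pm1(rng.integers(0, 2, size=16)) for _ in range(3))
x = p * q * r                 # the X that makes the product all +1
recovered = p * q * x         # each ±1 factor cancels itself
print(np.array_equal(prod, to_pm1(xnor_bits)), np.array_equal(recovered, r))  # True True
```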

In [ ]:
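pts=np.concatenate([pt1,pt2,pt3],axis=1)  # Parts 1-3 side by side along the width
ptx=find_initializer("238652a900475522").squeeze(0).transpose((1,2,0))  # the X operand, CHW -> HWC
# Map {0,1} -> {-1,1} and multiply the first two 256-wide slices with X to get the third slice
pt4=((pts[:,:256,:]*2-1)*(pts[:,256:256*2,:]*2-1)*ptx)[:,128:,:]
plt.imshow(pt4)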
In [ ]:
flg=np.concatenate([pt1,pt2,pt3,pt4],axis=1)
plt.imshow(flg)

Still, you could guess the flag from Part 4. I probably should have used a randomized flag or something similar to make guessing harder...