Hi! Sorry for making an NN challenge. I've tried to make it as easy as possible to solve for most CTFers.
Hope you enjoy this challenge!
import onnx
import matplotlib.pyplot as plt
import numpy as np
# Load the obfuscated network from net_obf.onnx (path relative to the working directory).
MODEL_PATH = "net_obf.onnx"
model = onnx.load(MODEL_PATH)
def find_node_on_graph(name, graph=None):
    """Return the first node named ``name`` in ``graph``, descending into ``If`` subgraphs.

    Nodes are visited in order. Each node's name is compared first. If the
    node is an ``If`` node, the graph held by each of its attributes (for
    ``If`` these are the then/else branch subgraphs) is searched recursively
    before moving on to the next node.

    Args:
        name: Node name to look for.
        graph: Graph to search. When omitted, the module-level ``model.graph``
            is used; it is looked up at call time rather than frozen into the
            default value when the function is defined.

    Returns:
        The matching node, or ``None`` if no node with that name exists.
    """
    if graph is None:
        graph = model.graph
    for node in graph.node:
        if node.name == name:
            return node
        if node.op_type == "If":
            for attr in node.attribute:
                found = find_node_on_graph(name, attr.g)
                if found is not None:
                    return found
    return None
def find_initializer(name, graph=None):
    """Return initializer ``name`` as a NumPy array, descending into ``If`` subgraphs.

    The graph's own initializers are checked first. After that, the graph held
    by every attribute of each ``If`` node (its branch subgraphs) is searched
    recursively, in node order.

    Args:
        name: Initializer name to look for.
        graph: Graph to search. When omitted, the module-level ``model.graph``
            is used; it is looked up at call time rather than frozen into the
            default value when the function is defined.

    Returns:
        ``onnx.numpy_helper.to_array`` of the matching initializer, or ``None``
        if no initializer with that name exists.
    """
    if graph is None:
        graph = model.graph
    for init in graph.initializer:
        if init.name == name:
            return onnx.numpy_helper.to_array(init)
    for node in graph.node:
        if node.op_type == "If":
            for attr in node.attribute:
                found = find_initializer(name, attr.g)
                if found is not None:
                    return found
    return None
If_0_then_branch
AveragePool -> Conv (but actually, just a weighted sum).
So it's just a sum of weights: on one side the weights are all 1, and on the other side they are only 0 or 1.
However, the condition is quite specific: the maximum possible value is 36864, yet the difference between the weighted sums must be less than 300 and the sum must be approximately 32400.
The weight of the second conv (node e9ebd4158d366acc)
is quite likely to be (close to) the target image itself.
# Sanity check: the first conv's kernel is all ones, i.e. a plain box sum.
conv_node1 = find_node_on_graph('510ca3cad4acbffd')
box_kernel = find_initializer(conv_node1.input[1])
assert (box_kernel == 1).all()

# Part 1 is read straight out of the second conv's kernel. Drop the leading
# axis and move the next axis to the end so imshow gets channels-last data.
conv_node2 = find_node_on_graph('e9ebd4158d366acc')
kernel2 = find_initializer(conv_node2.input[1])
pt1 = kernel2.squeeze(0).transpose((1, 2, 0))
plt.imshow(pt1)
<matplotlib.image.AxesImage at 0x7fdb049cca60>
If_0_then_branch__If_0_then_branch
Notation: I
means the result from the previous node.
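normalize -> mul(I*255) -> add(I+A) -> mul(I*137) -> mod(I%257) -> (concat(I|B) -> conv) (actually I-B)

In other words, with $I$ the normalized input, this branch computes $\big((255\,I + A)\cdot 137\big) \bmod 257 - B$.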
# Invert the modular chain: multiply by the inverse of 137 mod 257
# (three-argument pow with exponent -1 needs Python 3.8+), subtract the second
# initializer, and reduce mod 257.
# NOTE(review): presumably '21add3e5c814986b' is B and '8e10760202cf3d3f' is A
# from the chain described above; confirm.
inv_137 = pow(137, -1, 257)
w_21add = find_initializer('21add3e5c814986b')
w_8e10 = find_initializer('8e10760202cf3d3f')
F = ((w_21add * inv_137) - w_8e10) % 257
pt2 = F.squeeze(0).transpose((1, 2, 0))
plt.imshow(pt2)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7fdb026324c0>
If_0_then_branch__If_0_then_branch__If_0_then_branch
(The next part is located in this branch.)
A voting ensemble of three MLP-like structures.
import torch
def Linear(name):
    """Rebuild the ONNX node called ``name`` as a ``torch.nn.Linear`` layer.

    The node's second input is used as the weight and its third input as the
    bias. The layer is sized ``Linear(weight.shape[1], weight.shape[0])``, so
    the stored weight is taken to be (out_features, in_features), matching
    PyTorch's layout.

    NOTE(review): this presumes the node stores its weight already transposed
    (e.g. a Gemm with transB=1); confirm before reusing on other nodes.
    """
    node = find_node_on_graph(name)
    weight = find_initializer(node.input[1])
    bias = find_initializer(node.input[2])
    in_features, out_features = weight.shape[1], weight.shape[0]
    layer = torch.nn.Linear(in_features, out_features)
    # Overwrite the freshly initialised parameters with the exported values.
    layer.weight, layer.bias = (
        torch.nn.Parameter(torch.tensor(weight)),
        torch.nn.Parameter(torch.tensor(bias)),
    )
    return layer
def _mlp(activation, layer_names):
    """Build Flatten -> Linear -> act -> Linear -> ... -> Linear and put it in eval mode."""
    layers = [torch.nn.Flatten()]
    for position, layer_name in enumerate(layer_names):
        if position:
            # An activation goes between consecutive Linear layers only.
            layers.append(activation())
        layers.append(Linear(layer_name))
    return torch.nn.Sequential(*layers).eval()


# Two members of the ensemble rebuilt in PyTorch: tor1 uses SiLU, tor2 uses GELU.
tor1 = _mlp(torch.nn.SiLU, ("bd81aa784a4caf2d", "9b038baac1e62a86", "f4581a9ca6f0a910"))
tor2 = _mlp(torch.nn.GELU, ("931fe8d6e3b14f28", "89f47d3cdfe769db", "57e0c901e157786f"))
# Gradient-optimise a single-channel 48x192 image so that both rebuilt MLPs
# (tor1, tor2) classify it as class 1. An extra penalty pushes the first 8 and
# last 8 rows (of 48) toward 1.
adv=torch.ones((1,1,48,192),requires_grad=True)
optim=torch.optim.Adam([adv],lr=0.1)
criterion=torch.nn.CrossEntropyLoss()
loss=1
epoch=1
# Stop once the combined loss is <= 0.03, or after at most 99 steps.
while loss>0.03 and epoch<100:
    epoch+=1
    optim.zero_grad()
    # Clamp to [0, 1] and repeat the single channel 3 times along dim 1 for the MLPs.
    advi=adv.clamp(0,1).tile((1,3,1,1))
    loss=criterion(tor1(advi),torch.tensor([1]))+criterion(tor2(advi),torch.tensor([1]))
    # Squared penalty on rows 0-7 and 40-47 deviating from 1 (applied to the
    # unclamped adv).
    loss+=((1-adv[:,:,:8,:])**2).sum()+((1-adv[:,:,40:,:])**2).sum()
    loss.backward()
    # NOTE(review): tor1/tor2 parameters are nn.Parameters with requires_grad=True,
    # so backward() also accumulates unused gradients into them every step
    # (optim.zero_grad() only clears adv). Consider requires_grad_(False) on them.
    # Progress display: "\r" overwrites the same console line each step.
    print(str(loss.item())+" "*10,end="\r")
    optim.step()
0.04737356677651405
# Binarise the last clamped, tiled input from the optimisation loop at 0.5,
# then drop the batch axis and move channels last for imshow.
adv_np = advi.detach().numpy()
pt3 = (adv_np > 0.5).astype(np.float32).squeeze(0).transpose((1, 2, 0))
plt.imshow(pt3)
<matplotlib.image.AxesImage at 0x7fdb0259fa30>
If_0_then_branch__If_0_then_branch__If_0_then_branch__If_0_then_branch
(mul(I*2) -> sub(I-1)) (maps {0,1} to {-1,1}) -> mul(I[:256]*I[256:512]*I[512:768]*X)

This whole operation is itself an XNOR: with bits encoded as ±1, $(2a-1)(2b-1) = +1$ exactly when $a = b$.
# Stitch parts 1-3 side by side, then combine two 256-wide slices with X in
# ±1 form (2*v - 1); their product acts as an XNOR.
pts = np.concatenate([pt1, pt2, pt3], axis=1)
ptx = find_initializer("238652a900475522").squeeze(0).transpose((1, 2, 0))
left = pts[:, :256, :] * 2 - 1
right = pts[:, 256:256 * 2, :] * 2 - 1
pt4 = (left * right * ptx)[:, 128:, :]
plt.imshow(pt4)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7fda5989c610>
# Place all four recovered parts next to each other along axis 1 (the width).
flg = np.concatenate((pt1, pt2, pt3, pt4), axis=1)
plt.imshow(flg)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7fda5980f130>
Still, you could guess the flag from part 4 alone. I think I should have used a randomized flag, or something similar, to make guessing harder....