Skip to content

Commit 849fe95

Browse files
committed
fix train best models
1 parent 334c043 commit 849fe95

14 files changed

+2266
-515
lines changed

code/best_models.zip

5.58 KB
Binary file not shown.

code/gcn-training.ipynb

Lines changed: 694 additions & 514 deletions
Large diffs are not rendered by default.
Loading
Loading
-10.4 KB
Loading
-23.3 KB
Loading

code/graphics/surrogat_algo.xml

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
<mxfile host="app.diagrams.net">
2+
<diagram name="Surrogate Training Flow" id="d1">
3+
<mxGraphModel dx="1420" dy="728" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169">
4+
<root>
5+
<mxCell id="0"/>
6+
<mxCell id="1" parent="0"/>
7+
8+
<!-- Start -->
9+
<mxCell id="2" value="Start" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcccc;" vertex="1" parent="1">
10+
<mxGeometry x="240" y="20" width="140" height="60" as="geometry"/>
11+
</mxCell>
12+
13+
<!-- Init model -->
14+
<mxCell id="3" value="Initialize model f_sim" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
15+
<mxGeometry x="240" y="100" width="140" height="60" as="geometry"/>
16+
</mxCell>
17+
18+
<!-- Pick architecture -->
19+
<mxCell id="4" value="Pick an architecture" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
20+
<mxGeometry x="240" y="180" width="140" height="60" as="geometry"/>
21+
</mxCell>
22+
23+
<!-- Sample pair -->
24+
<mxCell id="5" value="Choose positive and negative pair" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
25+
<mxGeometry x="240" y="260" width="200" height="60" as="geometry"/>
26+
</mxCell>
27+
28+
<!-- Embeddings -->
29+
<mxCell id="6" value="Compute embeddings" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
30+
<mxGeometry x="240" y="340" width="140" height="60" as="geometry"/>
31+
</mxCell>
32+
33+
<!-- Loss -->
34+
<mxCell id="7" value="Calculate triplet loss" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
35+
<mxGeometry x="240" y="420" width="160" height="60" as="geometry"/>
36+
</mxCell>
37+
38+
<!-- Update -->
39+
<mxCell id="8" value="Update model weights" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
40+
<mxGeometry x="240" y="500" width="160" height="60" as="geometry"/>
41+
</mxCell>
42+
43+
<!-- Repeat -->
44+
<mxCell id="9" value="Repeat steps as needed" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;" vertex="1" parent="1">
45+
<mxGeometry x="240" y="580" width="160" height="60" as="geometry"/>
46+
</mxCell>
47+
48+
<!-- Return -->
49+
<mxCell id="10" value="Return f_sim" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcccc;" vertex="1" parent="1">
50+
<mxGeometry x="240" y="660" width="140" height="60" as="geometry"/>
51+
</mxCell>
52+
53+
<!-- Edges -->
54+
<mxCell id="11" style="endArrow=block;html=1;" edge="1" parent="1" source="2" target="3"><mxGeometry relative="1" as="geometry"/></mxCell>
55+
<mxCell id="12" style="endArrow=block;html=1;" edge="1" parent="1" source="3" target="4"><mxGeometry relative="1" as="geometry"/></mxCell>
56+
<mxCell id="13" style="endArrow=block;html=1;" edge="1" parent="1" source="4" target="5"><mxGeometry relative="1" as="geometry"/></mxCell>
57+
<mxCell id="14" style="endArrow=block;html=1;" edge="1" parent="1" source="5" target="6"><mxGeometry relative="1" as="geometry"/></mxCell>
58+
<mxCell id="15" style="endArrow=block;html=1;" edge="1" parent="1" source="6" target="7"><mxGeometry relative="1" as="geometry"/></mxCell>
59+
<mxCell id="16" style="endArrow=block;html=1;" edge="1" parent="1" source="7" target="8"><mxGeometry relative="1" as="geometry"/></mxCell>
60+
<mxCell id="17" style="endArrow=block;html=1;" edge="1" parent="1" source="8" target="9"><mxGeometry relative="1" as="geometry"/></mxCell>
61+
<mxCell id="18" style="endArrow=block;html=1;" edge="1" parent="1" source="9" target="10"><mxGeometry relative="1" as="geometry"/></mxCell>
62+
63+
</root>
64+
</mxGraphModel>
65+
</diagram>
66+
</mxfile>

code/graphics/surrogate_arch.png

40.3 KB
Loading

code/graphics/surrogate_arch.xml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
<mxfile host="app.diagrams.net" agent="Mozilla/5.0 (X11; Linux x86_64; rv:138.0) Gecko/20100101 Firefox/138.0" version="27.0.2">
2+
<diagram name="GAT Architecture" id="gat-diagram">
3+
<mxGraphModel dx="1072" dy="577" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
4+
<root>
5+
<mxCell id="0" />
6+
<mxCell id="1" parent="0" />
7+
<mxCell id="input" value="&lt;div&gt;Input Layer &lt;br&gt;&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#DAE8FC;" parent="1" vertex="1">
8+
<mxGeometry x="130" y="240" width="160" height="40" as="geometry" />
9+
</mxCell>
10+
<mxCell id="gat1" value="GATv2Conv + Linear" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#E1D5E7;" parent="1" vertex="1">
11+
<mxGeometry x="120" y="300" width="180" height="80" as="geometry" />
12+
</mxCell>
13+
<mxCell id="norm1" value="&lt;div&gt;Leaky ReLU&lt;/div&gt;&lt;div&gt;GraphNorm &lt;br&gt;&lt;/div&gt;Dropout " style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#D5E8D4;" parent="1" vertex="1">
14+
<mxGeometry x="120" y="400" width="180" height="60" as="geometry" />
15+
</mxCell>
16+
<mxCell id="pool" value="Global Max/Mean/Sum Pooling" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#F8CECC;" parent="1" vertex="1">
17+
<mxGeometry x="130" y="480" width="160" height="40" as="geometry" />
18+
</mxCell>
19+
<mxCell id="fc1" value="&lt;div&gt;Linear&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#F5F5F5;" parent="1" vertex="1">
20+
<mxGeometry x="400" y="300" width="160" height="30" as="geometry" />
21+
</mxCell>
22+
<mxCell id="fc_norm" value="&lt;div&gt;Leaky ReLU&lt;/div&gt;&lt;div&gt;LayerNorm&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#D5E8D4;" parent="1" vertex="1">
23+
<mxGeometry x="400" y="360" width="160" height="40" as="geometry" />
24+
</mxCell>
25+
<mxCell id="fc2" value="Linear " style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#F5F5F5;" parent="1" vertex="1">
26+
<mxGeometry x="400" y="420" width="160" height="40" as="geometry" />
27+
</mxCell>
28+
<mxCell id="out" value="Sigmoid (if output_dim == 1)" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#FFF2CC;" parent="1" vertex="1">
29+
<mxGeometry x="380" y="485" width="200" height="40" as="geometry" />
30+
</mxCell>
31+
<mxCell id="e1" style="edgeStyle=orthogonalEdgeStyle;rounded=0;exitX=0.5;exitY=1;" parent="1" source="input" target="gat1" edge="1">
32+
<mxGeometry relative="1" as="geometry" />
33+
</mxCell>
34+
<mxCell id="e2" style="edgeStyle=orthogonalEdgeStyle;rounded=0;" parent="1" source="gat1" target="norm1" edge="1">
35+
<mxGeometry relative="1" as="geometry" />
36+
</mxCell>
37+
<mxCell id="e3" style="edgeStyle=orthogonalEdgeStyle;rounded=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;" parent="1" source="norm1" target="pool" edge="1">
38+
<mxGeometry relative="1" as="geometry">
39+
<mxPoint x="210" y="479.9999999999999" as="targetPoint" />
40+
</mxGeometry>
41+
</mxCell>
42+
<mxCell id="e11" style="edgeStyle=orthogonalEdgeStyle;rounded=0;" parent="1" source="fc1" target="fc_norm" edge="1">
43+
<mxGeometry relative="1" as="geometry" />
44+
</mxCell>
45+
<mxCell id="e12" style="edgeStyle=orthogonalEdgeStyle;rounded=0;" parent="1" source="fc_norm" target="fc2" edge="1">
46+
<mxGeometry relative="1" as="geometry" />
47+
</mxCell>
48+
<mxCell id="e13" style="edgeStyle=orthogonalEdgeStyle;rounded=0;" parent="1" source="fc2" target="out" edge="1">
49+
<mxGeometry relative="1" as="geometry" />
50+
</mxCell>
51+
<mxCell id="1mhNudxIzqTfyHtObyKU-1" value="&lt;div&gt;&lt;font style=&quot;font-size: 17px;&quot;&gt;&lt;/font&gt;&lt;/div&gt;" style="text;html=1;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
52+
<mxGeometry x="50" y="360" width="60" height="30" as="geometry" />
53+
</mxCell>
54+
<mxCell id="1mhNudxIzqTfyHtObyKU-2" value="&lt;div&gt;Output Layer&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#DAE8FC;" vertex="1" parent="1">
55+
<mxGeometry x="400" y="550" width="160" height="40" as="geometry" />
56+
</mxCell>
57+
<mxCell id="1mhNudxIzqTfyHtObyKU-3" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;" edge="1" parent="1" source="out" target="1mhNudxIzqTfyHtObyKU-2">
58+
<mxGeometry width="50" height="50" relative="1" as="geometry">
59+
<mxPoint x="340" y="470" as="sourcePoint" />
60+
<mxPoint x="390" y="420" as="targetPoint" />
61+
</mxGeometry>
62+
</mxCell>
63+
<mxCell id="1mhNudxIzqTfyHtObyKU-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="pool" target="fc1">
64+
<mxGeometry width="50" height="50" relative="1" as="geometry">
65+
<mxPoint x="340" y="510" as="sourcePoint" />
66+
<mxPoint x="390" y="460" as="targetPoint" />
67+
<Array as="points">
68+
<mxPoint x="210" y="560" />
69+
<mxPoint x="340" y="560" />
70+
<mxPoint x="340" y="270" />
71+
<mxPoint x="480" y="270" />
72+
</Array>
73+
</mxGeometry>
74+
</mxCell>
75+
<mxCell id="1mhNudxIzqTfyHtObyKU-8" value="" style="shape=curlyBracket;whiteSpace=wrap;html=1;rounded=1;labelPosition=left;verticalLabelPosition=middle;align=right;verticalAlign=middle;" vertex="1" parent="1">
76+
<mxGeometry x="100" y="290" width="20" height="180" as="geometry" />
77+
</mxCell>
78+
</root>
79+
</mxGraphModel>
80+
</diagram>
81+
</mxfile>

code/greedy-finding-best-models.ipynb

Lines changed: 622 additions & 1 deletion
Large diffs are not rendered by default.

code/test.ipynb

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "07171380",
7+
"metadata": {},
8+
"outputs": [
9+
{
10+
"name": "stdout",
11+
"output_type": "stream",
12+
"text": [
13+
"hello world\n"
14+
]
15+
}
16+
],
17+
"source": [
18+
"print(\"hello world\")"
19+
]
20+
}
21+
],
22+
"metadata": {
23+
"kernelspec": {
24+
"display_name": "Python 3 (ipykernel)",
25+
"language": "python",
26+
"name": "python3"
27+
},
28+
"language_info": {
29+
"codemirror_mode": {
30+
"name": "ipython",
31+
"version": 3
32+
},
33+
"file_extension": ".py",
34+
"mimetype": "text/x-python",
35+
"name": "python",
36+
"nbconvert_exporter": "python",
37+
"pygments_lexer": "ipython3",
38+
"version": "3.12.3"
39+
}
40+
},
41+
"nbformat": 4,
42+
"nbformat_minor": 5
43+
}

0 commit comments

Comments
 (0)