
Commit ffa1575

refactor GCN, add function for using surrogate functions
1 parent 61f5574 · commit ffa1575

17 files changed (+33131 -7175 lines)

code/GCN.ipynb (-1)

This file was deleted.

code/GCN.py (+504)

Large diffs are not rendered by default.

code/Graph.py (+205)
import re

import numpy as np
import torch
from sklearn.preprocessing import OneHotEncoder
from graphviz import Digraph
from IPython.display import display

DARTS_OPS = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]

# One-hot encoding of the DARTS operation set; each Vertex stores its row.
encoder = OneHotEncoder(handle_unknown='ignore')
ops_array = np.array(DARTS_OPS).reshape(-1, 1)

DARTS_OPS_ONE_HOT = encoder.fit_transform(ops_array).toarray()

def extract_cells(arch_dict):
    """Split a flat architecture dict into normal-cell and reduction-cell edge lists."""
    normal_cell, reduction_cell = [], []
    tmp_list = []

    for key, value in arch_dict["architecture"].items():
        if key.startswith("normal/") or key.startswith("reduce/"):
            tmp_list.extend([key, value])

        if len(tmp_list) == 4:
            # Keep [op key, op name, input indices]; drop the "input_*" key itself.
            tmp_list.pop(2)
            if key.startswith("normal/"):
                normal_cell.append(tmp_list)
            else:
                reduction_cell.append(tmp_list)
            tmp_list = []

    return normal_cell, reduction_cell

class Vertex:
    def __init__(self, op, in_channel, out_channel):
        self.op = op
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.op_one_hot = DARTS_OPS_ONE_HOT[DARTS_OPS.index(op)]

    def __str__(self):
        return f"Op: {self.op} | In: {self.in_channel} | Out: {self.out_channel}"

    def __repr__(self):
        return self.__str__()

class Graph(torch.utils.data.Dataset):
    def __init__(self, model_dict, index=0):
        self.model_dict = model_dict
        self.normal_cell, self.reduction_cell = extract_cells(model_dict)

        self._normal_graph = self.make_graph(self.normal_cell)
        self._reduction_graph = self.make_graph(self.reduction_cell)

        self.normal_num_vertices, self.reduction_num_vertices = self.__len__()

        self.graph = self.make_full_graph()
        self.index = index

    def __len__(self):
        # Note: returns a tuple, so it is invoked directly as self.__len__()
        # throughout this class rather than via len().
        max_normal_out = max(vertex.out_channel for vertex in self._normal_graph)
        max_reduction_out = max(vertex.out_channel for vertex in self._reduction_graph)
        return max_normal_out, max_reduction_out

    def graph_size(self, graph):
        return max((vertex.out_channel for vertex in graph), default=0)

    def make_full_graph(self):
        graph = [vertex for vertex in self._normal_graph]
        graph = self._unite_graphs(graph, self._reduction_graph)

        # Append the global output node after both cells.
        max_channel_diff, _ = self.__len__()
        graph.append(Vertex("none", max_channel_diff * 2 + 1, max_channel_diff * 2 + 1))

        return graph

    def _unite_graphs(self, graph1, graph2):
        # Shift graph2's node indices past graph1 and merge the two vertex lists.
        graph1_size = self.graph_size(graph1)
        new_graph = [vertex for vertex in graph1]
        for vertex in graph2:
            new_vertex = Vertex(
                vertex.op,
                vertex.in_channel + graph1_size,
                vertex.out_channel + graph1_size,
            )
            new_graph.append(new_vertex)

        new_graph.sort(key=lambda vertex: (vertex.in_channel, vertex.out_channel))

        return new_graph

    def make_graph(self, cell):
        graph = []
        for value in cell:
            in_channel = int(value[2][0])
            out_channel = int(re.search(r"op_(\d+)_", value[0]).group(1))
            op = value[1]
            graph.append(Vertex(op, in_channel, out_channel))
        # The cell's two input nodes are represented as "none" vertices.
        graph.append(Vertex("none", 0, 0))
        graph.append(Vertex("none", 1, 1))

        graph.sort(key=lambda vertex: (vertex.in_channel, vertex.out_channel))

        return graph

    def show_graph(self):
        adj_matrix, _, _ = self.get_adjacency_matrix()
        graph_name = "Graph"

        dot = Digraph(comment=graph_name, format="png")
        dot.attr(rankdir="TB")

        num_nodes = len(self.graph)

        # Add nodes with their original labels
        for idx, vertex in enumerate(self.graph):
            label = (
                f"{{Op: {vertex.op} | "
                f"In: {vertex.in_channel} | "
                f"Out: {vertex.out_channel}}}"
            )
            dot.node(str(idx), label=label, shape="record")

        # Add edges based on the adjacency matrix
        for i in range(num_nodes):
            for j in range(num_nodes):
                if adj_matrix[i, j] == 1:
                    dot.edge(str(i), str(j))

        display(dot)

    def get_normal_graph(self):
        return self._normal_graph

    def get_reduction_graph(self):
        return self._reduction_graph

    def get_adjacency_matrix(self):
        adj_matrix_size = len(self.graph)
        max_channel_diff, _ = self.__len__()
        adj_matrix = np.zeros(shape=(adj_matrix_size, adj_matrix_size))

        operations = [vertex.op for vertex in self.graph]
        operations_one_hot = [vertex.op_one_hot for vertex in self.graph]
        for i in range(adj_matrix_size):
            for j in range(adj_matrix_size):
                if j == i:
                    continue
                vertex_1 = self.graph[i]
                vertex_2 = self.graph[j]

                if (vertex_1.out_channel == vertex_2.in_channel) and (
                    (
                        vertex_1.in_channel <= max_channel_diff
                        and vertex_2.out_channel <= max_channel_diff
                    )
                    or (
                        vertex_1.in_channel >= max_channel_diff
                        and vertex_2.out_channel >= max_channel_diff
                    )
                ):
                    adj_matrix[i, j] = 1

                if (  # Add an edge from c_k to the input of the next cell
                    (vertex_1.op == "none")
                    and (vertex_2.op == "none")
                    and (vertex_1.out_channel == 1)
                    and (vertex_2.in_channel == 6)
                ):
                    adj_matrix[i, j] = 1

        # Connect the remaining nodes to the output.
        for i in range(adj_matrix_size):
            for j in range(adj_matrix_size):
                if j == i:
                    continue
                vertex_1 = self.graph[i]
                vertex_2 = self.graph[j]

                if (np.all(adj_matrix[i, :] == 0)) and (
                    (
                        (vertex_2.op == "none")
                        and (vertex_2.in_channel == max_channel_diff)
                        and (vertex_1.in_channel < max_channel_diff)
                    )
                    or (
                        (vertex_2.out_channel == 2 * max_channel_diff + 1)
                        and (vertex_1.out_channel > max_channel_diff)
                    )
                ):
                    adj_matrix[i, j] = 1

        operations_one_hot = np.array(operations_one_hot)
        return adj_matrix, operations, operations_one_hot
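
For reference, a minimal usage sketch of Graph. The architecture dict below is a hypothetical stub, not taken from the commit: extract_cells expects paired "normal/op_<node>_<slot>" / "normal/input_<node>_<slot>" entries, and likewise for "reduce/".

# Hypothetical architecture stub in the format extract_cells expects.
arch_dict = {
    "architecture": {
        "normal/op_2_0": "sep_conv_3x3",
        "normal/input_2_0": [0],
        "reduce/op_2_0": "max_pool_3x3",
        "reduce/input_2_0": [1],
    }
}

g = Graph(arch_dict)
adj_matrix, operations, operations_one_hot = g.get_adjacency_matrix()
print(operations)  # operation name per node, sorted by (in_channel, out_channel)
g.show_graph()     # renders the united normal + reduction graph in a notebook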

code/create_GCN_dataset.py (+63)
import json
import os

import numpy as np
from scipy.special import softmax
from tqdm import tqdm

def load_json_from_directory(directory_path):
    json_data = []
    for root, _, files in os.walk(directory_path):
        for file in tqdm(files, desc="Processing JSON files"):
            if file.endswith('.json'):
                file_path = os.path.join(root, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    try:
                        data = json.load(f)
                        json_data.append(data)
                    except json.JSONDecodeError as e:
                        print(f"Error decoding JSON from file {file_path}: {e}")
    return json_data

def apply_softmax_to_predictions(data):
    # Convert raw logits to class probabilities, in place.
    for item in tqdm(data, desc="Applying softmax to predictions"):
        if "test_predictions" in item:
            predictions = np.array(item["test_predictions"])
            softmaxed = softmax(predictions, axis=1)
            item["test_predictions"] = softmaxed.tolist()

def save_dicts_as_json(data, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    for i, item in enumerate(tqdm(data, desc="Saving JSON files")):
        file_name = f"sample_{i:04d}.json"
        file_path = os.path.join(output_dir, file_name)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(item, f, ensure_ascii=False, indent=4)

def apply_argmax_to_predictions(data):
    # Collapse probabilities (or logits) to hard class labels, in place.
    for item in tqdm(data, desc="Applying argmax to predictions"):
        if "test_predictions" in item:
            predictions = np.array(item["test_predictions"])
            argmaxed = np.argmax(predictions, axis=1)
            item["test_predictions"] = argmaxed.tolist()

dir_path = "dataset_logits_fixed"
first_arch_dicts = load_json_from_directory(dir_path)

# Probabilities dataset: logits -> softmax.
output_dir = "dataset_probs"
apply_softmax_to_predictions(first_arch_dicts)
save_dicts_as_json(first_arch_dicts, output_dir)

# Hard-label dataset: probabilities -> argmax.
apply_argmax_to_predictions(first_arch_dicts)
output_dir = "tmp_dataset"
save_dicts_as_json(first_arch_dicts, output_dir)

# Merge with the second dataset and save the combined set.
second_arch_dicts = load_json_from_directory("second_dataset")

first_arch_dicts.extend(second_arch_dicts)
output_dir = "third_dataset"
save_dicts_as_json(first_arch_dicts, output_dir)
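
The helpers above assume each JSON sample is a flat object whose "test_predictions" field is a 2-D list of raw logits, one row per test example. A minimal illustration with made-up values:

# Hypothetical sample showing the assumed schema; only "test_predictions"
# is touched by the helpers above.
sample = {
    "test_predictions": [
        [2.0, 0.5, -1.0],  # raw logits for test example 0
        [0.1, 0.3, 3.2],   # raw logits for test example 1
    ],
}

apply_softmax_to_predictions([sample])  # each row now sums to 1
apply_argmax_to_predictions([sample])   # rows collapse to class indices [0, 2]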

code/data_generator.ipynb (+29,513 -7,010)

Large diffs are not rendered by default.

code/dataset/arch_dicts.json (+1 -1)

Large diffs are not rendered by default.
