1
+ import numpy as np
2
+ import re
3
+ import torch
4
+ from sklearn .preprocessing import OneHotEncoder
5
+ from graphviz import Digraph
6
+ from IPython .display import display
7
+
8
+ DARTS_OPS = [
9
+ 'none' ,
10
+ 'max_pool_3x3' ,
11
+ 'avg_pool_3x3' ,
12
+ 'skip_connect' ,
13
+ 'sep_conv_3x3' ,
14
+ 'sep_conv_5x5' ,
15
+ 'dil_conv_3x3' ,
16
+ 'dil_conv_5x5' ,
17
+ ]
18
+
19
+ encoder = OneHotEncoder (handle_unknown = 'ignore' )
20
+ encoder = OneHotEncoder (handle_unknown = 'ignore' )
21
+ ops_array = np .array (DARTS_OPS ).reshape (- 1 , 1 )
22
+
23
+ DARTS_OPS_ONE_HOT = encoder .fit_transform (ops_array ).toarray ()
24
+
25
+ def extract_cells (arch_dict ):
26
+ normal_cell , reduction_cell = [], []
27
+ tmp_list = []
28
+
29
+ for key , value in arch_dict ["architecture" ].items ():
30
+ if key .startswith ("normal/" ) or key .startswith ("reduce/" ):
31
+ tmp_list .extend ([key , value ])
32
+
33
+ if len (tmp_list ) == 4 :
34
+ tmp_list .pop (2 )
35
+ if key .startswith ("normal/" ):
36
+ normal_cell .append (tmp_list )
37
+ else :
38
+ reduction_cell .append (tmp_list )
39
+ tmp_list = []
40
+
41
+ return normal_cell , reduction_cell
42
+
43
+ class Vertex :
44
+ def __init__ (self , op , in_channel , out_channel ):
45
+ self .op = op
46
+ self .in_channel = in_channel
47
+ self .out_channel = out_channel
48
+ self .op_one_hot = DARTS_OPS_ONE_HOT [DARTS_OPS .index (op )]
49
+
50
+ def __str__ (self ):
51
+ return f"Op: { self .op } | In: { self .in_channel } | Out: { self .out_channel } "
52
+ def __repr__ (self ):
53
+ return self .__str__ ()
54
+
55
+ class Graph (torch .utils .data .Dataset ):
56
+ def __init__ (self , model_dict , index = 0 ):
57
+ self .model_dict = model_dict
58
+ self .normal_cell , self .reduction_cell = extract_cells (model_dict )
59
+
60
+ self ._normal_graph = self .make_graph (self .normal_cell )
61
+ self ._reduction_graph = self .make_graph (self .reduction_cell )
62
+
63
+ self .normal_num_vertices , self .reduction_num_vertices = self .__len__ ()
64
+
65
+ self .graph = self .make_full_graph ()
66
+ self .index = index
67
+
68
+ def __len__ (self ):
69
+ max_normal_out = max (vertex .out_channel for vertex in self ._normal_graph )
70
+ max_reduction_out = max (vertex .out_channel for vertex in self ._reduction_graph )
71
+ return max_normal_out , max_reduction_out
72
+
73
+ def graph_size (self , graph ):
74
+ return max ((vertex .out_channel for vertex in graph ), default = 0 )
75
+
76
+ def make_full_graph (self ):
77
+ graph = [vertex for vertex in self ._normal_graph ]
78
+ graph = self ._unite_graphs (graph , self ._reduction_graph )
79
+
80
+ max_channel_diff , _ = self .__len__ ()
81
+ graph .append (Vertex ("none" , max_channel_diff * 2 + 1 , max_channel_diff * 2 + 1 ))
82
+
83
+ return graph
84
+
85
+ def _unite_graphs (self , graph1 , graph2 ):
86
+ graph1_size = self .graph_size (graph1 )
87
+ new_graph = [vertex for vertex in graph1 ]
88
+ for vertex in graph2 :
89
+ new_vertex = Vertex (
90
+ vertex .op ,
91
+ vertex .in_channel + graph1_size ,
92
+ vertex .out_channel + graph1_size ,
93
+ )
94
+ new_graph .append (new_vertex )
95
+
96
+ new_graph .sort (key = lambda vertex : (vertex .in_channel , vertex .out_channel ))
97
+
98
+ return new_graph
99
+
100
+ def make_graph (self , cell ):
101
+ graph = []
102
+ for value in cell :
103
+ in_channel = int (value [2 ][0 ])
104
+ out_channel = int (re .search (r"op_(\d+)_" , value [0 ]).group (1 ))
105
+ op = value [1 ]
106
+ graph .append (Vertex (op , in_channel , out_channel ))
107
+ graph .append (Vertex ("none" , 0 , 0 ))
108
+ graph .append (Vertex ("none" , 1 , 1 ))
109
+
110
+ graph .sort (key = lambda vertex : (vertex .in_channel , vertex .out_channel ))
111
+
112
+ return graph
113
+
114
+ def show_graph (self ):
115
+ adj_matrix , operations , _ = self .get_adjacency_matrix ()
116
+ graph_name = "Graph"
117
+
118
+ dot = Digraph (comment = graph_name , format = "png" )
119
+ dot .attr (rankdir = "TB" )
120
+
121
+ num_nodes = len (self .graph )
122
+
123
+ # Добавляем узлы с оригинальными метками
124
+ for idx , vertex in enumerate (self .graph ):
125
+ label = (
126
+ f"{{Op: { vertex .op } | "
127
+ f"In: { vertex .in_channel } | "
128
+ f"Out: { vertex .out_channel } }}"
129
+ )
130
+ dot .node (str (idx ), label = label , shape = "record" )
131
+
132
+ # Добавляем связи на основе матрицы смежности
133
+ for i in range (num_nodes ):
134
+ for j in range (num_nodes ):
135
+ if adj_matrix [i , j ] == 1 :
136
+ dot .edge (str (i ), str (j ))
137
+
138
+ display (dot )
139
+
140
+ def get_normal_graph (self ):
141
+ return self ._normal_graph
142
+
143
+ def get_reduction_graph (self ):
144
+ return self ._reduction_graph
145
+
146
+ def get_adjacency_matrix (self ):
147
+ adj_matrix_size = len (self .graph )
148
+ max_channel_diff , _ = self .__len__ ()
149
+ adj_matrix = np .zeros (shape = (adj_matrix_size , adj_matrix_size ))
150
+
151
+ operations = [vertex .op for vertex in self .graph ]
152
+ operations_one_hot = [vertex .op_one_hot for vertex in self .graph ]
153
+ for i in range (adj_matrix_size ):
154
+ for j in range (adj_matrix_size ):
155
+ if j == i :
156
+ continue
157
+ vertex_1 = self .graph [i ]
158
+ vertex_2 = self .graph [j ]
159
+
160
+ if (vertex_1 .out_channel == vertex_2 .in_channel ) and (
161
+ (
162
+ vertex_1 .in_channel <= max_channel_diff
163
+ and vertex_2 .out_channel <= max_channel_diff
164
+ )
165
+ or (
166
+ vertex_1 .in_channel >= max_channel_diff
167
+ and vertex_2 .out_channel >= max_channel_diff
168
+ )
169
+ ):
170
+
171
+ adj_matrix [i , j ] = 1
172
+
173
+ if ( # Добавляем ребро из c_k на вход следующей клетке
174
+ (vertex_1 .op == "none" )
175
+ and (vertex_2 .op == "none" )
176
+ and (vertex_1 .out_channel == 1 )
177
+ and (vertex_2 .in_channel == 6 )
178
+ ):
179
+ adj_matrix [i , j ] = 1
180
+
181
+ # Соединим оставшиеся узлы с выходом.
182
+
183
+ for i in range (adj_matrix_size ):
184
+ for j in range (adj_matrix_size ):
185
+ if j == i :
186
+ continue
187
+ vertex_1 = self .graph [i ]
188
+ vertex_2 = self .graph [j ]
189
+
190
+ if (np .all (adj_matrix [i , :] == 0 )) and (
191
+ (
192
+ (vertex_2 .op == "none" )
193
+ and (vertex_2 .in_channel == max_channel_diff )
194
+ and (vertex_1 .in_channel < max_channel_diff )
195
+ )
196
+ or (
197
+ (vertex_2 .out_channel == 2 * max_channel_diff + 1 )
198
+ and (vertex_1 .out_channel > max_channel_diff )
199
+ )
200
+ ):
201
+ adj_matrix [i , j ] = 1
202
+
203
+ adj_matrix = np .array (adj_matrix )
204
+ operations_one_hot = np .array (operations_one_hot )
205
+ return adj_matrix , operations , operations_one_hot
0 commit comments