@@ -37,17 +37,20 @@ def __init__(
         self.gc1 = GCNConv(input_dim, hidden_dim)
         self.gc2 = GCNConv(hidden_dim, hidden_dim)
         self.graph_norm = GraphNorm(hidden_dim)
+        self.layer_norm = nn.LayerNorm(hidden_dim)
         self.dropout = nn.Dropout(dropout)
         self.pooling = pooling
         self.fc = nn.Linear(hidden_dim, embedding_dim)

     def forward(self, x, edge_index):
         x = F.relu(self.gc1(x, edge_index))
         x = self.graph_norm(x)
+        x = self.layer_norm(x)
         x = self.dropout(x)

         x = F.relu(self.gc2(x, edge_index))
         x = self.graph_norm(x)
+        x = self.layer_norm(x)
         x = self.dropout(x)

         # Global aggregation of node features to obtain a whole-graph representation
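For reference, a minimal standalone sketch of the normalization stack this hunk introduces. The surrounding class is not shown in the diff, so the wiring and dimensions here are illustrative, not the author's exact module:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphNorm

# One conv block as composed above: GraphNorm normalizes using statistics
# over the graph's nodes, then LayerNorm normalizes each node's feature
# vector independently.
conv = GCNConv(8, 16)
graph_norm = GraphNorm(16)
layer_norm = nn.LayerNorm(16)
dropout = nn.Dropout(0.5)

x = torch.randn(4, 8)                      # 4 nodes, 8 input features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])  # directed 4-cycle

h = F.relu(conv(x, edge_index))
h = graph_norm(h)   # graph-level statistics (single graph: batch=None)
h = layer_norm(h)   # per-node feature statistics
h = dropout(h)
print(h.shape)      # torch.Size([4, 16])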
@@ -74,22 +77,25 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max"):

         self.input_dim = input_dim
         self.output_dim = output_dim
+        self.hidden_dim = 64  # base hidden-layer size

-        self.gc1 = GCNConv(input_dim, 128)
-        self.gc2 = GCNConv(128, 256)
-        self.gc3 = GCNConv(256, 64)
-        self.layer_norm = nn.LayerNorm(64)
-        self.fc = nn.Linear(64, output_dim)
-        self.dropout = nn.Dropout(dropout)
-        self.pooling = pooling
-
-        self.key = nn.Linear(256, 256)
-        self.query = nn.Linear(256, 256)
+        self.gc1 = GCNConv(input_dim, self.hidden_dim)
+        self.gc2 = GCNConv(self.hidden_dim, 256)
+        self.gc3 = GCNConv(256, 512)
+        self.gc4 = GCNConv(512, self.hidden_dim)

         self.residual_proj = (
-            nn.Linear(input_dim, 64) if input_dim != 64 else nn.Identity()
+            nn.Linear(input_dim, self.hidden_dim) if input_dim != self.hidden_dim else nn.Identity()
         )

+        self.layer_norm = nn.LayerNorm(self.hidden_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.pooling = pooling
+
+        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
+        self.fc_norm = nn.LayerNorm(self.hidden_dim)
+        self.fc2 = nn.Linear(self.hidden_dim, output_dim)
+
     def forward(self, x, edge_index):
         residual = self.residual_proj(x)
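The residual projection above follows a standard pattern: project the raw input to the hidden width only when the two widths differ, so the skip connection can be added to the last conv output. A minimal sketch with illustrative dimensions (not taken from the diff):

import torch
import torch.nn as nn

input_dim, hidden_dim = 8, 64
residual_proj = (
    nn.Linear(input_dim, hidden_dim) if input_dim != hidden_dim else nn.Identity()
)

x = torch.randn(4, input_dim)    # raw node features
h = torch.randn(4, hidden_dim)   # stand-in for the gc4 output
out = nn.LayerNorm(hidden_dim)(h + residual_proj(x))
print(out.shape)                 # torch.Size([4, 64])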
@@ -99,22 +105,13 @@ def forward(self, x, edge_index):
         x = F.leaky_relu(self.gc2(x, edge_index))
         x = self.dropout(x)

-        keys = self.key(x)  # [N, 128]
-        queries = self.query(x)  # [N, 128]
-        attn_scores = torch.mm(queries, keys.T)  # [N, N]
-
-        row, col = edge_index
-        mask = torch.zeros_like(attn_scores)
-        mask[row, col] = 1
-        attn_scores = attn_scores * mask
-        attn_scores = F.softmax(attn_scores, dim=-1)
-
-        x = torch.mm(attn_scores, x)  # [N, N] * [N, 128] → [N, 128]
-
         x = F.leaky_relu(self.gc3(x, edge_index))
         x = self.dropout(x)

-        x = self.layer_norm(x + residual)  # LayerNorm
+        x = F.leaky_relu(self.gc4(x, edge_index))
+        x = self.dropout(x)
+
+        x = self.layer_norm(x + residual)

         if self.pooling == "max":
             x = torch.max(x, dim=0).values
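A note on the deleted attention block: multiplying the score matrix by a 0/1 mask before the softmax does not actually restrict attention to edges, since non-edge entries end up with a logit of 0 rather than -inf and still receive weight proportional to e^0 after normalization (and the dense [N, N] matrix scales poorly with graph size). If edge-restricted attention is wanted later, masking with -inf, or a built-in layer such as torch_geometric's GATConv, behaves correctly. A sketch of the -inf fix:

import torch
import torch.nn.functional as F

scores = torch.randn(4, 4)                # raw attention logits
mask = torch.zeros(4, 4, dtype=torch.bool)
mask[[0, 1, 2, 3], [1, 2, 3, 0]] = True   # edges of a 4-cycle

# Zeroed logits still get softmax weight proportional to e^0 = 1:
leaky = F.softmax(scores * mask, dim=-1)

# -inf logits vanish after softmax, so each row attends to its edges only:
strict = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(strict[0])  # only column 1 is non-zero in row 0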
@@ -125,9 +122,13 @@ def forward(self, x, edge_index):
         else:
             raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")

-        x = self.fc(x)
+        x = self.fc1(x)
+        x = self.fc_norm(x)
+        x = F.leaky_relu(x)
+        x = self.fc2(x)
+
         if self.output_dim == 1:
-            x = nn.Sigmoid()(x)
+            x = torch.sigmoid(x)

         return x

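The single output projection becomes a small two-layer head over the pooled graph vector, and torch.sigmoid replaces nn.Sigmoid()(x), which constructed a throwaway module on every forward pass. An illustrative trace of the new head (widths assumed to match the hidden_dim above):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, output_dim = 64, 1
fc1 = nn.Linear(hidden_dim, hidden_dim)
fc_norm = nn.LayerNorm(hidden_dim)
fc2 = nn.Linear(hidden_dim, output_dim)

g = torch.randn(hidden_dim)               # pooled graph representation
out = fc2(F.leaky_relu(fc_norm(fc1(g))))
out = torch.sigmoid(out)                  # output_dim == 1 → probability
print(out.shape)                          # torch.Size([1])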
@@ -373,7 +374,6 @@ def train_model_diversity(
             marker="o",
             label="Valid Loss",
         )
-        plt.title("Training and Validation Loss Over Epochs")
         plt.xlabel("Epoch")
         plt.ylabel("Loss")
         plt.grid(True)
@@ -481,7 +481,6 @@ def train_model_accuracy(
             marker="o",
             label="Valid Loss",
         )
-        plt.title("Training and Validation Loss Over Epochs")
         plt.xlabel("Epoch")
         plt.ylabel("Loss")
         plt.ylim(0, 0.002)