
Commit 56f53fa

push alchemy
1 parent 1a53165 commit 56f53fa

File tree

1 file changed: +164 -0 lines changed

Alchemy/main_alchemy.py

+164
@@ -0,0 +1,164 @@
# built off of https://github.com/chrsmrrs/SpeqNets/blob/master/neural_graph/main_1_alchemy_10K.py
# original author: Christopher Morris
import sys

sys.path.insert(0, '..')
sys.path.insert(0, '.')

import numpy as np
import os.path as osp
import time

import torch
import torch.nn as nn

from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

from baseline_gin import NetGINE
from sign_net.transform import EVDTransform
from sign_net.sign_net import SignNetGNN
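
# Two models are supported: the NetGINE baseline and SignNet (selected via
# model_name below). MIN_LR is the early-stopping threshold checked in the
# training loop; PATIENCE controls the LR scheduler's plateau patience.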
# model_name = 'gine'
model_name = 'signnet'
MIN_LR = 1e-6
PATIENCE = 20
def get_model(model_name):
    if model_name == 'gine':
        hidden_dim = 64
        model = NetGINE(hidden_dim)
    elif model_name == 'signnet':
        hidden_dim = 108
        model = SignNetGNN(6, 4, n_hid=hidden_dim, n_out=12, nl_signnet=8,
                           nl_gnn=16, nl_rho=8, ignore_eigval=False,
                           gnn_type='GINEConv')
    else:
        raise ValueError('invalid model name')
    return model.to(device)
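
# NOTE (assumption): the positional arguments 6 and 4 passed to SignNetGNN above
# are presumably Alchemy's node- and edge-feature dimensions as expected by the
# sign_net package; n_out=12 matches the 12 regression targets of alchemy_full.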

plot_all = []
results = []
results_log = []
for _ in range(5):
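    # each of the 5 repetitions retrains the chosen model from scratch on the same split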
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    plot_it = []
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'datasets', "alchemy_full")

    # each .index file holds a single comma-separated line of dataset indices
    def read_indices(fname):
        with open(fname, "r") as f:
            return [int(i) for i in f.read().strip().split(",")]

    indices_train = read_indices("train_al_10.index")
    indices_val = read_indices("val_al_10.index")
    indices_test = read_indices("test_al_10.index")

    indices = list(indices_train)
    indices.extend(indices_val)
    indices.extend(indices_test)
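
    # EVDTransform is assumed here to precompute the eigendecomposition of the
    # symmetrically normalized graph Laplacian ('sym'), which SignNet consumes
    # as its positional-encoding input.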
    transform = EVDTransform('sym')
    dataset = TUDataset(path, name="alchemy_full", transform=transform)[indices]
    print('Num points:', len(dataset))

    # standardize the 12 regression targets
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean.to(device), std.to(device)
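
    # the index lists were concatenated as train | val | test, so the slicing
    # below assumes the files contain 10k / 1k / 1k entries in that order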
    train_dataset = dataset[0:10000]
    val_dataset = dataset[10000:11000]
    test_dataset = dataset[11000:]

    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = get_model(model_name)
    print(model)
    print('Trainable params:', sum(p.numel() for p in model.parameters() if p.requires_grad))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=PATIENCE,
                                                           min_lr=MIN_LR / 2)
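
    # min_lr is set to MIN_LR / 2 so the learning rate can fall below MIN_LR
    # and trigger the lr <= MIN_LR early-stopping check in the training loop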

    def train():
        model.train()
        loss_all = 0

        lf = nn.L1Loss()
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            loss = lf(model(data), data.y)

            loss.backward()
            loss_all += loss.item() * data.num_graphs
            optimizer.step()
        # mean L1 loss per graph over the whole training set
        return loss_all / len(train_loader.dataset)
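
    # test() reports per-target MAE; since the targets were standardized, the
    # "* std ... / std" factors cancel and this is MAE on the standardized scale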
    @torch.no_grad()
    def test(loader):
        model.eval()
        error = torch.zeros([1, 12]).to(device)

        for data in loader:
            data = data.to(device)
            error += ((data.y * std - model(data) * std).abs() / std).sum(dim=0)

        error = error / len(loader.dataset)
        error_log = torch.log(error)

        return error.mean().item(), error_log.mean().item()
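
    # evaluate on the test set whenever validation MAE improves, so the reported
    # test numbers correspond to the best-validation epoch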
    best_val_error = None
    for epoch in range(1, 1001):
        start_time = time.time()
        lr = scheduler.optimizer.param_groups[0]['lr']
        loss = train()
        val_error, _ = test(val_loader)

        scheduler.step(val_error)
        if best_val_error is None or val_error <= best_val_error:
            test_error, test_error_log = test(test_loader)
            best_val_error = val_error
        elapsed = time.time() - start_time

        print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, '
              'Test MAE: {:.7f}, Test log MAE: {:.7f}, Time (s): {:.2f}'.format(
                  epoch, lr, loss, val_error, test_error, test_error_log, elapsed))

        if lr <= MIN_LR:
            print("Converged.")
            break

    results.append(test_error)
    results_log.append(test_error_log)
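
# report mean and standard deviation of the test metrics over the 5 runs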
print('Trainable params:', sum(p.numel() for p in model.parameters() if p.requires_grad))
print("########################")
print('\nTest MAE')
print(results)
results = np.array(results)
print(results.mean(), results.std())

print('\nTest Log MAE')
print(results_log)
results_log = np.array(results_log)
print(results_log.mean(), results_log.std())
