# built off of https://github.com/chrsmrrs/SpeqNets/blob/master/neural_graph/main_1_alchemy_10K.py
# original author: Christopher Morris
import sys

sys.path.insert(0, '..')
sys.path.insert(0, '.')

import numpy as np
import os.path as osp
import time

import torch
import torch.nn as nn

from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

from baseline_gin import NetGINE
from sign_net.transform import EVDTransform
from sign_net.sign_net import SignNetGNN

# model_name = 'gine'
model_name = 'signnet'
MIN_LR = 1e-6
PATIENCE = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

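# Build the selected model and move it to the chosen device.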
def get_model(model_name):
    if model_name == 'gine':
        hidden_dim = 64
        model = NetGINE(hidden_dim)
    elif model_name == 'signnet':
        hidden_dim = 108
        model = SignNetGNN(6, 4, n_hid=hidden_dim, n_out=12, nl_signnet=8,
                           nl_gnn=16, nl_rho=8, ignore_eigval=False,
                           gnn_type='GINEConv')
    else:
        raise ValueError('invalid model name')
    return model.to(device)


results = []
results_log = []
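# Five independent runs; per-run test metrics are aggregated at the end.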
for _ in range(5):

    path = osp.join(osp.dirname(osp.realpath(__file__)), 'datasets', 'alchemy_full')

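    # Fixed 10K Alchemy split: each .index file holds one comma-separated line
    # of dataset indices (the files come with the SpeqNets setup referenced at
    # the top of this script).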
    with open("train_al_10.index", "r") as infile:
        indices_train = [int(i) for i in infile.readline().split(",")]

    with open("val_al_10.index", "r") as infile:
        indices_val = [int(i) for i in infile.readline().split(",")]

    with open("test_al_10.index", "r") as infile:
        indices_test = [int(i) for i in infile.readline().split(",")]

    indices = indices_train + indices_val + indices_test

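    # EVDTransform attaches a Laplacian eigendecomposition to each graph
    # ('sym' presumably selects the symmetrically normalized Laplacian), which
    # SignNet consumes as a sign-invariant positional encoding.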
    transform = EVDTransform('sym')
    dataset = TUDataset(path, name="alchemy_full", transform=transform)[indices]
    print('Num points:', len(dataset))

    # Standardize the 12 regression targets to zero mean / unit variance.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean.to(device), std.to(device)

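    # Split in index-file order: first 10,000 graphs train, next 1,000 val,
    # the rest test.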
    train_dataset = dataset[0:10000]
    val_dataset = dataset[10000:11000]
    test_dataset = dataset[11000:]

    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = get_model(model_name)
    print(model)
    print('Trainable params:', sum(p.numel() for p in model.parameters() if p.requires_grad))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    # Halve the LR when validation MAE plateaus; min_lr sits just below MIN_LR
    # so the lr <= MIN_LR stopping check below can trigger.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=PATIENCE,
                                                           min_lr=MIN_LR / 2)


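    # One optimization epoch; returns the mean L1 loss over the training set.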
    def train():
        model.train()
        loss_all = 0

        lf = nn.L1Loss()
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            loss = lf(model(data), data.y)

            loss.backward()
            loss_all += loss.item() * data.num_graphs
            optimizer.step()
        return loss_all / len(train_loader.dataset)


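    # Per-task MAE over a loader; returns the mean MAE and mean log MAE
    # across the 12 regression targets.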
    @torch.no_grad()
    def test(loader):
        model.eval()
        error = torch.zeros([1, 12]).to(device)

        for data in loader:
            data = data.to(device)
            # The * std / std cancels, so this is the MAE on the standardized
            # targets (kept in this form to match the original script).
            error += ((data.y * std - model(data) * std).abs() / std).sum(dim=0)

        error = error / len(loader.dataset)
        error_log = torch.log(error)

        return error.mean().item(), error_log.mean().item()


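    # Train for up to 1000 epochs; the test set is re-evaluated whenever the
    # validation MAE improves, and training stops once the LR decays to MIN_LR.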
    best_val_error = None
    for epoch in range(1, 1001):
        start_time = time.time()
        lr = scheduler.optimizer.param_groups[0]['lr']
        loss = train()
        val_error, _ = test(val_loader)

        scheduler.step(val_error)
        if best_val_error is None or val_error <= best_val_error:
            test_error, test_error_log = test(test_loader)
            best_val_error = val_error
        elapsed = time.time() - start_time

        print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, '
              'Test MAE: {:.7f}, Test log MAE: {:.7f}, Time (s): {:.2f}'.format(
                  epoch, lr, loss, val_error, test_error, test_error_log, elapsed))

        if lr <= MIN_LR:
            print("Converged.")
            break

    results.append(test_error)
    results_log.append(test_error_log)

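# Report mean and standard deviation over the 5 runs (the last run's model is
# used only for the parameter-count printout).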
print('Trainable params:', sum(p.numel() for p in model.parameters() if p.requires_grad))
print("########################")
print('\nTest MAE')
print(results)
results = np.array(results)
print(results.mean(), results.std())

print('\nTest log MAE')
print(results_log)
results_log = np.array(results_log)
print(results_log.mean(), results_log.std())