
Commit c946ae1

add gine signnet zinc
1 parent 921b8dd commit c946ae1

15 files changed: +1186 -0 lines changed

GINESignNetPyG/README.md

+43
## SignNet and GINE implementations in PyTorch Geometric

For reproducing our results on ZINC in the paper.

This is approximately the same code as the one used for our Alchemy experiments.

## Setup

```
# params
# newest packages as of 10/6/2021
ENV=pyg
CUDA=11.1
TORCH=1.9.1
PYG=2.0.1

# create env
conda create --name $ENV python=3.9 -y
conda activate $ENV

# install pytorch
conda install pytorch=$TORCH torchvision torchaudio cudatoolkit=$CUDA -c pytorch -c nvidia -y

# install pyg2.0
conda install pyg=$PYG -c pyg -c conda-forge -y

# install ogb
pip install ogb

# install rdkit
conda install -c conda-forge rdkit -y

# update yacs and tensorboard
pip install yacs==0.1.8 --force # PyG currently uses 0.1.6, which doesn't support None arguments.
pip install tensorboard
pip install matplotlib
```

## Run
```
python -m train.zinc model.gnn_type SignNet
```
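
Any option defined in core/config.py can be overridden positionally in the same way. A hedged example (the extra overrides below are illustrative, not required for reproduction):

```
# illustrative: run the GINEConv baseline (the config default) with a longer epoch budget
python -m train.zinc model.gnn_type GINEConv train.epochs 200
```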

GINESignNetPyG/core/__init__.py

Whitespace-only changes.

GINESignNetPyG/core/config.py

+97
from yacs.config import CfgNode as CN
import os
import argparse

def set_cfg(cfg):

    # ------------------------------------------------------------------------ #
    # Basic options
    # ------------------------------------------------------------------------ #
    # Dataset name
    cfg.dataset = 'ZINC'
    # Number of additional workers for data loading
    cfg.num_workers = 8
    # CUDA device number; used on machines with multiple GPUs
    cfg.device = 0
    # Additional string appended to the logging name
    cfg.handtune = ''
    # Whether to fix the random seed to remove randomness
    cfg.seed = None
    # Whether to downsample the dataset; useful for faster tuning on large datasets
    cfg.downsample = False
    # Version string
    cfg.version = 'final'
    # Task index, for simulation datasets
    cfg.task = -1

    # ------------------------------------------------------------------------ #
    # Training options
    # ------------------------------------------------------------------------ #
    cfg.train = CN()
    # Total graph mini-batch size
    cfg.train.batch_size = 128
    # Maximum number of epochs
    cfg.train.epochs = 100
    # Number of runs with random initialization
    cfg.train.runs = 3
    # Base learning rate
    cfg.train.lr = 0.001
    # Number of steps without improvement before reducing the learning rate
    cfg.train.lr_patience = 50
    # Learning rate decay factor
    cfg.train.lr_decay = 0.5
    # L2 regularization (weight decay)
    cfg.train.wd = 0.
    # Dropout rate
    cfg.train.dropout = 0.

    # ------------------------------------------------------------------------ #
    # Model options
    # ------------------------------------------------------------------------ #
    cfg.model = CN()
    # GNN type used, see core.model_utils.pyg_gnn_wrapper for all options
    cfg.model.gnn_type = 'GINEConv' # change to list later
    # Hidden size of the model
    cfg.model.hidden_size = 128
    # Number of GNN layers (doesn't include MLP layers)
    cfg.model.num_layers = 4
    # Number of SignNet layers
    cfg.model.num_layers_sign = 4
    # Pooling type for generating graph/subgraph embeddings from node embeddings
    cfg.model.pool = 'add'

    return cfg

# YACS principle: if an option is defined in a YACS config object, the program
# should set it with cfg.merge_from_list(opts), not by defining, for example,
# --train-scales as a command-line argument that is then used to set cfg.TRAIN.SCALES.

def update_cfg(cfg, args_str=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="", metavar="FILE", help="Path to config file")
    # the opts arg needs to match set_cfg
    parser.add_argument("opts", default=[], nargs=argparse.REMAINDER,
                        help="Modify config options using the command line")

    if isinstance(args_str, str):
        # parse from a string
        args = parser.parse_args(args_str.split())
    else:
        # parse from the command line
        args = parser.parse_args()
    # Clone the original cfg
    cfg = cfg.clone()

    # Update from config file
    if os.path.isfile(args.config):
        cfg.merge_from_file(args.config)

    # Update from command line
    cfg.merge_from_list(args.opts)

    return cfg

"""
Global variable
"""
cfg = set_cfg(CN())
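
A minimal usage sketch: `update_cfg` accepts either the real command line or an override string, mirroring the README's `python -m train.zinc model.gnn_type SignNet` invocation. The override values below are illustrative:

```python
from core.config import cfg, update_cfg

# each "dotted.option value" pair is merged via cfg.merge_from_list;
# yacs coerces "200" back to an int to match the declared type
cfg = update_cfg(cfg, "model.gnn_type SignNet train.epochs 200")
print(cfg.model.gnn_type)  # SignNet
print(cfg.train.epochs)    # 200
```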

GINESignNetPyG/core/log.py

+43
# A simple logger that can log training curves and final performance
from torch.utils.tensorboard import SummaryWriter # tensorboard
import logging, os, sys, shutil
import datetime

def config_logger(cfg, OUT_PATH="results/", time=False):
    # the time option is used for debugging different model architectures
    data_name = cfg.dataset
    if cfg.handtune:
        data_name += f'-{cfg.handtune}'
    # generate config_string
    os.makedirs(os.path.join(OUT_PATH, cfg.version), exist_ok=True)
    config_string = f'T[{cfg.task}] GNN[{cfg.model.gnn_type}] L[{cfg.model.num_layers_sign}-{cfg.model.num_layers}] '\
                    f'H[{cfg.model.hidden_size}] Pool[{cfg.model.pool}] '\
                    f'Reg[{cfg.train.dropout}-{cfg.train.wd}] Seed[{cfg.seed}] GPU[{cfg.device}]'

    # set up the tensorboard writer
    writer_folder = os.path.join(OUT_PATH, cfg.version, data_name, config_string)
    if time:
        writer_folder = os.path.join(writer_folder, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    if os.path.isdir(writer_folder): shutil.rmtree(writer_folder) # reset the folder; can also skip the reset
    writer = SummaryWriter(writer_folder)

    # set up logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger_file = os.path.join(OUT_PATH, cfg.version, data_name, 'summary.log')
    fh = logging.FileHandler(logger_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(fh)

    # redirect stdout print, better for large-scale experiments
    os.makedirs(os.path.join('logs', data_name), exist_ok=True)
    # sys.stdout = open(f'logs/{data_name}/{config_string}.txt', 'w')

    # log the configuration
    print("-"*50)
    print(cfg)
    print("-"*50)
    print('Time:', datetime.datetime.now().strftime("%Y/%m/%d - %H:%M"))
    print(config_string)
    return writer, logger, config_string
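
A minimal usage sketch (the loss values are placeholders): `config_logger` returns a plain `SummaryWriter` and `logging.Logger`, so training code can record curves and final numbers with the standard APIs.

```python
from core.config import cfg
from core.log import config_logger

writer, logger, config_string = config_logger(cfg)

for epoch, loss in enumerate([0.9, 0.7, 0.6]):   # placeholder values
    writer.add_scalar('train/loss', loss, epoch) # tensorboard curve
logger.info(f'{config_string} | final loss 0.6') # appended to summary.log
writer.close()
```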

GINESignNetPyG/core/model.py

+80
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
import core.model_utils.pyg_gnn_wrapper as gnn_wrapper
from core.model_utils.elements import MLP, DiscreteEncoder, Identity, BN
from torch_geometric.nn.inits import reset

class GNN(nn.Module):
    # this version uses nin as the hidden size instead of nout, resulting in a larger model
    def __init__(self, nfeat_node, nfeat_edge, nhid, nout, nlayer, gnn_type, dropout=0, pooling='add', bn=BN, dos_bins=0, res=True):
        super().__init__()
        self.input_encoder = DiscreteEncoder(nhid-dos_bins) if nfeat_node is None else MLP(nfeat_node, nhid-dos_bins, 1)
        self.edge_encoders = nn.ModuleList([DiscreteEncoder(nhid) if nfeat_edge is None else MLP(nfeat_edge, nhid, 1) for _ in range(nlayer)])
        self.convs = nn.ModuleList([getattr(gnn_wrapper, gnn_type)(nhid, nhid, bias=not bn) for _ in range(nlayer)]) # set bias=False for BN
        self.norms = nn.ModuleList([nn.BatchNorm1d(nhid) if bn else Identity() for _ in range(nlayer)])
        self.output_encoder = MLP(nhid, nout, nlayer=2, with_final_activation=False, with_norm=False if pooling=='mean' else True)
        self.size_embedder = nn.Embedding(200, nhid)
        self.linear = nn.Linear(2*nhid, nhid)

        if dos_bins > 0:
            self.ldos_encoder = MLP(dos_bins, nhid, nlayer=2, with_final_activation=True, with_norm=True)
            self.dos_encoder = MLP(dos_bins, nhid, nlayer=2, with_final_activation=False, with_norm=True)

        self.pooling = pooling
        self.dropout = dropout
        self.res = res
        # for additional features from (L)DOS
        self.dos_bins = dos_bins

    def reset_parameters(self):
        self.input_encoder.reset_parameters()
        self.output_encoder.reset_parameters()
        self.size_embedder.reset_parameters()
        self.linear.reset_parameters()
        if self.dos_bins > 0:
            self.dos_encoder.reset_parameters()
            self.ldos_encoder.reset_parameters()
        for edge_encoder, conv, norm in zip(self.edge_encoders, self.convs, self.norms):
            edge_encoder.reset_parameters()
            conv.reset_parameters()
            norm.reset_parameters()

    def forward(self, data, additional_x=None):
        x = self.input_encoder(data.x.squeeze())

        # for PDOS
        if self.dos_bins > 0:
            x = torch.cat([x, data.pdos], dim=-1)
            # x += self.ldos_encoder(data.pdos)

        if additional_x is not None:
            x = self.linear(torch.cat([x, additional_x], dim=-1))

        ori_edge_attr = data.edge_attr
        if ori_edge_attr is None:
            ori_edge_attr = data.edge_index.new_zeros(data.edge_index.size(-1))

        previous_x = x
        for edge_encoder, layer, norm in zip(self.edge_encoders, self.convs, self.norms):
            edge_attr = edge_encoder(ori_edge_attr)
            x = layer(x, data.edge_index, edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            if self.res:
                x += previous_x
                previous_x = x

        if self.pooling == 'mean':
            graph_size = scatter(torch.ones_like(x[:,0], dtype=torch.int64), data.batch, dim=0, reduce='add')
            x = scatter(x, data.batch, dim=0, reduce='mean') + self.size_embedder(graph_size)
        else:
            x = scatter(x, data.batch, dim=0, reduce='add')

        if self.dos_bins > 0:
            x = x + self.dos_encoder(data.dos)
        x = self.output_encoder(x)
        return x
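
A minimal construction sketch for a ZINC-style setup, assuming `core.model_utils.pyg_gnn_wrapper` exposes `GINEConv` (the config default). The toy `Data` object is fabricated for illustration; real batches come from the training pipeline:

```python
import torch
from torch_geometric.data import Data
from core.model import GNN

# nhid/nlayer mirror the config defaults; nfeat_node=None and nfeat_edge=None
# select DiscreteEncoder for integer atom/bond types
model = GNN(nfeat_node=None, nfeat_edge=None, nhid=128, nout=1,
            nlayer=4, gnn_type='GINEConv')
model.eval()  # BatchNorm needs eval mode for this single toy graph

# a toy 3-node cycle with integer node/edge types (illustrative only)
data = Data(x=torch.zeros(3, dtype=torch.long),
            edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),
            edge_attr=torch.zeros(3, dtype=torch.long),
            batch=torch.zeros(3, dtype=torch.long))
out = model(data)  # shape [num_graphs, nout] == [1, 1]
```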

GINESignNetPyG/core/model_utils/__init__.py

Whitespace-only changes.
GINESignNetPyG/core/model_utils/elements.py

+71
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

BN = True
# BN = False
RUNNING_STAT = True


class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input

    def reset_parameters(self):
        pass

class DiscreteEncoder(nn.Module):
    def __init__(self, hidden_channels, max_num_features=10, max_num_values=500): #10
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(max_num_values, hidden_channels)
                                         for i in range(max_num_features)])

    def reset_parameters(self):
        for embedding in self.embeddings:
            embedding.reset_parameters()

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(1)
        out = 0
        for i in range(x.size(1)):
            out += self.embeddings[i](x[:, i])
        return out

class MLP(nn.Module):
    def __init__(self, nin, nout, nlayer=2, with_final_activation=True, with_norm=BN, bias=True, nhid=None):
        super().__init__()
        n_hid = nin if nhid is None else nhid
        # bias only on the final layer when it has no norm/activation; set bias=False for BN
        self.layers = nn.ModuleList([nn.Linear(nin if i==0 else n_hid,
                                               n_hid if i<nlayer-1 else nout,
                                               bias=True if (i==nlayer-1 and not with_final_activation and bias) # TODO: revise later
                                                    or (not with_norm) else False)
                                     for i in range(nlayer)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(n_hid if i<nlayer-1 else nout, track_running_stats=RUNNING_STAT) if with_norm else Identity()
                                    for i in range(nlayer)])
        self.nlayer = nlayer
        self.with_final_activation = with_final_activation
        self.residual = (nin==nout) ## TODO: test whether this is needed

    def reset_parameters(self):
        for layer, norm in zip(self.layers, self.norms):
            layer.reset_parameters()
            norm.reset_parameters()

    def forward(self, x):
        previous_x = x
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            x = layer(x)
            if i < self.nlayer-1 or self.with_final_activation:
                x = norm(x)
                x = F.relu(x)

        # if self.residual:
        #     x = x + previous_x
        return x
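
A quick shape check with toy tensors (illustrative only): `DiscreteEncoder` sums one embedding per integer feature column, and `MLP` maps between widths with BatchNorm on all but the final layer.

```python
import torch
from core.model_utils.elements import DiscreteEncoder, MLP

enc = DiscreteEncoder(hidden_channels=128)
x = torch.randint(0, 500, (32, 3))  # 32 nodes, 3 integer features
h = enc(x)                          # [32, 128]: sum of per-column embeddings

mlp = MLP(128, 64, nlayer=2, with_final_activation=False)
out = mlp(h)                        # [32, 64]
```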
