Commit a9cf893

committed: initial push alchemy
1 parent 56f53fa commit a9cf893

15 files changed: +766 −0 lines

Alchemy/README.md

+21
@@ -0,0 +1,21 @@
## Graph Regression Experiments on Alchemy

### Usage

To run the experiments for SignNet, use `python main_alchemy.py`.

### Implementation

Our SignNet model is implemented in PyTorch Geometric in the `sign_net` folder.

### Setup

Requirements are in `setup.sh`. Simply running `bash setup.sh` will usually create a conda environment called `torch-1-9` that works for these experiments, which you can then activate with `conda activate torch-1-9`.

You may have to edit the `CUDA` variable in `setup.sh` depending on the CUDA version of your GPUs. We use PyTorch 1.9 and PyTorch Geometric 2.0.1.
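After setup, a minimal sketch for confirming the environment matches these versions (run inside the activated `torch-1-9` environment):

import torch
import torch_geometric

print(torch.__version__)             # expect 1.9.x
print(torch.version.cuda)            # expect 11.1, or the CUDA version you set
print(torch_geometric.__version__)   # expect 2.0.1
print(torch.cuda.is_available())     # True if the GPU setup worked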
### Attribution

We built on the SpeqNets repo by Christopher Morris et al. (no license) [[link](https://github.com/chrsmrrs/SpeqNets/blob/master/neural_graph/main_1_alchemy_10K.py)].

The Alchemy dataset is from "Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models", Chen et al. 2019 [[arXiv link](https://arxiv.org/abs/1906.09427)].
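For context, one hedged way to load Alchemy in PyTorch Geometric; `alchemy_full` is the TU Dortmund collection's name for the dataset, and this commit's `main_alchemy.py` may load it differently:

from torch_geometric.datasets import TUDataset

# Hypothetical loading path, not confirmed by this commit.
dataset = TUDataset(root='data/alchemy', name='alchemy_full')
print(len(dataset), dataset[0])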

Alchemy/baseline_gin.py

+61
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing, Set2Set


class GINConv(MessagePassing):
    # GIN convolution with edge features: edge attributes are encoded and
    # added to neighbor messages (GINE-style).
    def __init__(self, emb_dim, dim1, dim2):
        super(GINConv, self).__init__(aggr="add")

        self.bond_encoder = Sequential(Linear(emb_dim, dim1), ReLU(), Linear(dim1, dim1))
        self.mlp = Sequential(Linear(dim1, dim1), ReLU(), Linear(dim1, dim2))
        self.eps = nn.Parameter(torch.Tensor([0]))

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


class NetGINE(nn.Module):
    def __init__(self, dim):
        super(NetGINE, self).__init__()

        num_features = 6  # Alchemy node feature dimension

        self.conv1 = GINConv(4, num_features, dim)  # 4 = edge feature dimension
        self.conv2 = GINConv(4, dim, dim)
        self.conv3 = GINConv(4, dim, dim)
        self.conv4 = GINConv(4, dim, dim)
        self.conv5 = GINConv(4, dim, dim)
        self.conv6 = GINConv(4, dim, dim)

        self.set2set = Set2Set(1 * dim, processing_steps=6)

        self.fc1 = Linear(2 * dim, dim)  # Set2Set doubles the feature dimension
        self.fc4 = Linear(dim, 12)       # 12 regression targets on Alchemy

    def forward(self, data):
        x = data.x

        x_1 = F.relu(self.conv1(x, data.edge_index, data.edge_attr))
        x_2 = F.relu(self.conv2(x_1, data.edge_index, data.edge_attr))
        x_3 = F.relu(self.conv3(x_2, data.edge_index, data.edge_attr))
        x_4 = F.relu(self.conv4(x_3, data.edge_index, data.edge_attr))
        x_5 = F.relu(self.conv5(x_4, data.edge_index, data.edge_attr))
        x_6 = F.relu(self.conv6(x_5, data.edge_index, data.edge_attr))
        x = x_6
        x = self.set2set(x, data.batch)
        x = F.relu(self.fc1(x))
        x = self.fc4(x)
        return x
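As a quick sanity check, a hypothetical way to run `NetGINE` on a toy batch; the shapes (6-dim node features, 4-dim edge features, 12 outputs) follow the definitions above, but the values are random, not real Alchemy data:

import torch
from torch_geometric.data import Data, Batch

# One toy graph: 3 nodes, 6-dim node features, 4-dim edge features.
x = torch.randn(3, 6)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.randn(4, 4)
batch = Batch.from_data_list([Data(x=x, edge_index=edge_index, edge_attr=edge_attr)])

model = NetGINE(dim=64)
out = model(batch)
print(out.shape)  # torch.Size([1, 12]): one 12-target prediction per graph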

Alchemy/setup.sh

+19
@@ -0,0 +1,19 @@
# modified from setup in https://github.com/LingxiaoShawn/GNNAsKernel
ENV=torch-1-9
CUDA=11.1
TORCH=1.9.1
PYG=2.0.1

# create env
conda create --name $ENV python=3.9 -y
# note: in a non-interactive shell, `conda activate` may require sourcing conda's
# shell hook first, e.g. `source "$(conda info --base)/etc/profile.d/conda.sh"`
conda activate $ENV

# install pytorch
conda install pytorch=$TORCH torchvision torchaudio cudatoolkit=$CUDA -c pytorch -c nvidia -y

# install pyg2.0
conda install pyg=$PYG -c pyg -c conda-forge -y

# update yacs
pip install yacs==0.1.8 --force  # PyG currently uses 0.1.6, which doesn't support None arguments
pip install matplotlib

Alchemy/sign_net/__init__.py

Whitespace-only changes.

Alchemy/sign_net/model.py

+65
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
import sign_net.model_utils.pyg_gnn_wrapper as gnn_wrapper
from sign_net.model_utils.elements import MLP, DiscreteEncoder, Identity, BN
from torch_geometric.nn.inits import reset


class GNN(nn.Module):
    # this version uses nin as the hidden size instead of nout, resulting in a larger model
    def __init__(self, nfeat_node, nfeat_edge, nhid, nout, nlayer, gnn_type, dropout=0, pooling='add', bn=BN, res=True):
        super().__init__()
        self.input_encoder = DiscreteEncoder(nhid) if nfeat_node is None else MLP(nfeat_node, nhid, 1)
        self.edge_encoders = nn.ModuleList([DiscreteEncoder(nhid) if nfeat_edge is None else MLP(nfeat_edge, nhid, 1) for _ in range(nlayer)])
        self.convs = nn.ModuleList([getattr(gnn_wrapper, gnn_type)(nhid, nhid, bias=not bn) for _ in range(nlayer)])  # set bias=False for BN
        self.norms = nn.ModuleList([nn.BatchNorm1d(nhid) if bn else Identity() for _ in range(nlayer)])
        self.output_encoder = MLP(nhid, nout, nlayer=2, with_final_activation=False, with_norm=False if pooling == 'mean' else True)
        # self.size_embedder = nn.Embedding(200, nhid)
        self.linear = nn.Linear(2 * nhid, nhid)

        self.pooling = pooling
        self.dropout = dropout
        self.res = res

    def reset_parameters(self):
        self.input_encoder.reset_parameters()
        self.output_encoder.reset_parameters()
        # self.size_embedder.reset_parameters()
        self.linear.reset_parameters()
        for edge_encoder, conv, norm in zip(self.edge_encoders, self.convs, self.norms):
            edge_encoder.reset_parameters()
            conv.reset_parameters()
            norm.reset_parameters()

    def forward(self, data, additional_x=None):
        x = self.input_encoder(data.x.squeeze())

        if additional_x is not None:
            x = self.linear(torch.cat([x, additional_x], dim=-1))

        ori_edge_attr = data.edge_attr
        if ori_edge_attr is None:
            ori_edge_attr = data.edge_index.new_zeros(data.edge_index.size(-1))

        previous_x = x
        for edge_encoder, layer, norm in zip(self.edge_encoders, self.convs, self.norms):
            edge_attr = edge_encoder(ori_edge_attr)
            x = layer(x, data.edge_index, edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            if self.res:  # residual connection between message-passing layers
                x = x + previous_x
                previous_x = x

        if self.pooling == 'mean':
            graph_size = scatter(torch.ones_like(x[:, 0], dtype=torch.int64), data.batch, dim=0, reduce='add')
            x = scatter(x, data.batch, dim=0, reduce='mean')  # + self.size_embedder(graph_size)
        else:
            x = scatter(x, data.batch, dim=0, reduce='add')

        x = self.output_encoder(x)
        return x
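For orientation, a hypothetical instantiation of this backbone. `gnn_type` is resolved by name from `pyg_gnn_wrapper`, so the string must match a conv class exposed there; the `'GINEConv'` name below is an assumption, not confirmed by this commit:

# Hypothetical: assumes pyg_gnn_wrapper exposes a class named 'GINEConv'.
# nfeat_node=None / nfeat_edge=None select the DiscreteEncoder path for
# integer-coded features; nout=12 matches Alchemy's 12 regression targets.
model = GNN(nfeat_node=None, nfeat_edge=None, nhid=64, nout=12,
            nlayer=4, gnn_type='GINEConv', dropout=0.0, pooling='add')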

Alchemy/sign_net/model_utils/__init__.py

Whitespace-only changes.
Alchemy/sign_net/model_utils/elements.py

+71
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

BN = True
# BN = False
RUNNING_STAT = True


class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input

    def reset_parameters(self):
        pass


class DiscreteEncoder(nn.Module):
    # Embeds each integer-coded feature column separately and sums the embeddings.
    def __init__(self, hidden_channels, max_num_features=10, max_num_values=6):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(max_num_values, hidden_channels)
                                         for i in range(max_num_features)])

    def reset_parameters(self):
        for embedding in self.embeddings:
            embedding.reset_parameters()

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(1)
        out = 0
        for i in range(x.size(1)):
            out += self.embeddings[i](x[:, i])
        return out


class MLP(nn.Module):
    def __init__(self, nin, nout, nlayer=2, with_final_activation=True, with_norm=BN, bias=True, nhid=None):
        super().__init__()
        n_hid = nin if nhid is None else nhid
        self.layers = nn.ModuleList([nn.Linear(nin if i == 0 else n_hid,
                                               n_hid if i < nlayer-1 else nout,
                                               bias=True if (i == nlayer-1 and not with_final_activation and bias)  # TODO: revise later
                                               or (not with_norm) else False)  # set bias=False for BN
                                     for i in range(nlayer)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(n_hid if i < nlayer-1 else nout, track_running_stats=RUNNING_STAT) if with_norm else Identity()
                                    for i in range(nlayer)])
        self.nlayer = nlayer
        self.with_final_activation = with_final_activation
        self.residual = (nin == nout)  # TODO: test whether we need this

    def reset_parameters(self):
        for layer, norm in zip(self.layers, self.norms):
            layer.reset_parameters()
            norm.reset_parameters()

    def forward(self, x):
        previous_x = x
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            x = layer(x)
            if i < self.nlayer-1 or self.with_final_activation:
                x = norm(x)
                x = F.relu(x)

        # if self.residual:
        #     x = x + previous_x
        return x
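A small, hedged usage example of these building blocks (shapes illustrative; input values must stay below `max_num_values`):

import torch

enc = DiscreteEncoder(hidden_channels=64)
x = torch.randint(0, 6, (5, 3))  # 5 nodes, 3 integer features, values in [0, 6)
h = enc(x)                       # [5, 64]: per-column embeddings, summed

mlp = MLP(nin=64, nout=32, nlayer=2)  # BatchNorm inside, so keep train mode and batch > 1
out = mlp(h)                          # [5, 32]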
Alchemy/sign_net/model_utils/pyg_gnn_wrapper.py

+102
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
import torch.nn.functional as F
from sign_net.model_utils.elements import Identity


class MaskedBN(nn.Module):
    # BatchNorm over the last (channel) dim, restricted to unmasked positions.
    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, x, mask=None):
        # x: n x k x d
        # mask: n x k
        if mask is None:
            return self.bn(x.transpose(1, 2)).transpose(1, 2)
        x[mask] = self.bn(x[mask])
        return x


class MaskedLN(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=1e-6)

    def reset_parameters(self):
        self.ln.reset_parameters()

    def forward(self, x, mask=None):
        if mask is None:
            return self.ln(x)
        x[mask] = self.ln(x[mask])
        return x


class MaskedMLP(nn.Module):
    def __init__(self, nin, nout, nlayer=2, with_final_activation=True, with_norm=True, bias=True, nhid=None):
        super().__init__()
        n_hid = nin if nhid is None else nhid
        self.layers = nn.ModuleList([nn.Linear(nin if i == 0 else n_hid,
                                               n_hid if i < nlayer-1 else nout,
                                               bias=True if (i == nlayer-1 and not with_final_activation and bias)  # TODO: revise later
                                               or (not with_norm) else False)  # set bias=False for BN
                                     for i in range(nlayer)])
        self.norms = nn.ModuleList([MaskedBN(n_hid if i < nlayer-1 else nout) if with_norm else Identity()
                                    for i in range(nlayer)])
        self.nlayer = nlayer
        self.with_final_activation = with_final_activation
        self.residual = (nin == nout)  # TODO: test whether we need this

    def reset_parameters(self):
        for layer, norm in zip(self.layers, self.norms):
            layer.reset_parameters()
            norm.reset_parameters()

    def forward(self, x, mask=None):
        # x: n x k x d
        previous_x = x
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            x = layer(x)
            if mask is not None:
                x[~mask] = 0  # zero out padded positions after every linear layer
            if i < self.nlayer-1 or self.with_final_activation:
                x = norm(x, mask)
                x = F.relu(x)
        return x


class MaskedGINConv(nn.Module):
    def __init__(self, nin, nout, bias=True, nhid=None):
        super().__init__()
        self.nn = MaskedMLP(nin, nout, 2, False, bias=bias, nhid=nhid)
        self.layer = gnn.GINConv(Identity(), train_eps=True)

    def reset_parameters(self):
        self.nn.reset_parameters()
        self.layer.reset_parameters()

    def forward(self, x, edge_index, edge_attr, mask=None):
        x = self.layer(x, edge_index)
        if mask is not None:
            if x[~mask].numel() == 0:  # debug output: nothing is actually masked out
                print('~mask numel = 0!!')
                print('x shape', x.shape)
                print('mask shape', mask.shape)
            # assert x[~mask].max() == 0
        x = self.nn(x, mask)
        # assert x[~mask].max() == 0
        return x


class MaskedGINEConv(nn.Module):
    def __init__(self, nin, nout, bias=True, nhid=None):
        super().__init__()
        self.nn = MaskedMLP(nin, nout, 2, False, bias=bias, nhid=nhid)
        self.layer = gnn.GINEConv(Identity(), train_eps=True)

    def reset_parameters(self):
        self.nn.reset_parameters()
        self.layer.reset_parameters()

    def forward(self, x, edge_index, edge_attr, mask=None):
        if mask is not None:  # guard added: the original asserted unconditionally, which fails when mask is None
            assert x[~mask].max() == 0
        x = self.layer(x, edge_index, edge_attr)
        if mask is not None:
            x[~mask] = 0
        x = self.nn(x, mask)
        # assert x[~mask].max() == 0
        return x
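A hedged sketch of exercising `MaskedGINConv` with a simple node-level padding mask; toy shapes, and the real model may use a different (e.g. per-eigenvector) mask convention:

import torch

n, d = 6, 16
x = torch.randn(n, d)
mask = torch.ones(n, dtype=torch.bool)
mask[-1] = False                 # treat the last node as padding
x[~mask] = 0                     # padded entries must be zero going in
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

conv = MaskedGINConv(nin=d, nout=d)
out = conv(x, edge_index, edge_attr=None, mask=mask)
print(out.shape)                 # torch.Size([6, 16])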
