Skip to content

Commit 7df81c8

Browse files
committed
DPDP code
0 parents  commit 7df81c8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+7186
-0
lines changed

.gitignore

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
data/
2+
logs/
3+
results/
4+
plots/
5+
__pycache__/
6+
.idea
7+
*/.ipynb_checkpoints
8+
*.log
9+
*.bak

README.md

+309
Large diffs are not rendered by default.

config.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import json
2+
3+
4+
class Settings(dict):
5+
"""Experiment configuration options.
6+
7+
Wrapper around in-built dict class to access members through the dot operation.
8+
9+
Experiment parameters:
10+
"expt_name": Name/description of experiment, used for logging.
11+
12+
"train_filepath": Training set path
13+
"val_filepath": Validation set path
14+
"test_filepath": Test set path
15+
16+
"num_nodes": Number of nodes in TSP tours
17+
18+
"node_dim": Number of dimensions for each node
19+
"voc_nodes_in": Input node signal vocabulary size
20+
"voc_nodes_out": Output node prediction vocabulary size
21+
"voc_edges_in": Input edge signal vocabulary size
22+
"voc_edges_out": Output edge prediction vocabulary size
23+
24+
"beam_size": Beam size for beamsearch procedure (-1 for disabling beamsearch)
25+
26+
"hidden_dim": Dimension of model's hidden state
27+
"num_layers": Number of GCN layers
28+
"mlp_layers": Number of MLP layers
29+
"aggregation": Node aggregation scheme in GCN (`mean` or `sum`)
30+
31+
"max_epochs": Maximum training epochs
32+
"val_every": Interval (in epochs) at which validation is performed
33+
"test_every": Interval (in epochs) at which testing is performed
34+
35+
"batch_size": Batch size
36+
"batches_per_epoch": Batches per epoch (-1 for using full training set)
37+
"accumulation_steps": Number of steps for gradient accumulation (DO NOT USE: BUGGY)
38+
"num_segments_checkpoint": How many checkpoint chunks to create
39+
40+
"learning_rate": Initial learning rate
41+
"decay_rate": Learning rate decay parameter
42+
"""
43+
44+
def __init__(self, config_dict):
45+
super().__init__()
46+
for key in config_dict:
47+
self[key] = config_dict[key]
48+
49+
def __getattr__(self, attr):
50+
return self[attr]
51+
52+
def __setitem__(self, key, value):
53+
return super().__setitem__(key, value)
54+
55+
def __setattr__(self, key, value):
56+
return self.__setitem__(key, value)
57+
58+
__delattr__ = dict.__delitem__
59+
60+
61+
def get_default_config():
62+
"""Returns default settings object.
63+
"""
64+
return Settings(json.load(open("./configs/default.json")))
65+
66+
67+
def get_config(filepath):
68+
"""Returns settings from json file.
69+
"""
70+
config = get_default_config()
71+
config.update(Settings(json.load(open(filepath))))
72+
return config

configs/default.json

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"expt_name": "default",
3+
4+
"train_filepath": "./data/tsp10_train_concorde.txt",
5+
"val_filepath": "./data/tsp10_val_concorde.txt",
6+
"test_filepath": "./data/tsp10_test_concorde.txt",
7+
8+
"num_nodes": 10,
9+
"num_neighbors": -1,
10+
11+
"node_dim": 2,
12+
"voc_nodes_in": 2,
13+
"voc_nodes_out": 2,
14+
"voc_edges_in": 3,
15+
"voc_edges_out": 2,
16+
17+
"beam_size": 10,
18+
19+
"hidden_dim": 50,
20+
"num_layers": 3,
21+
"mlp_layers": 2,
22+
"aggregation": "mean",
23+
24+
"max_epochs": 10,
25+
"val_every": 5,
26+
"test_every": 10,
27+
28+
"batch_size": 20,
29+
"batches_per_epoch": 500,
30+
"accumulation_steps": 1,
31+
32+
"learning_rate": 0.001,
33+
"decay_rate": 1.01
34+
}

configs/vrp_nazari100.json

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"expt_name": "vrp_nazari100",
3+
4+
"train_filepath": "data/vrp/vrp_nazari100_train_seed42.pkl",
5+
"train_filepath_solution": "results/vrp/vrp_nazari100_train_seed42/vrp_nazari100_train_seed42-lkh.pkl",
6+
7+
"val_filepath": "data/vrp/vrp_nazari100_validation_seed4321.pkl",
8+
"val_filepath_solution": "results/vrp/vrp_nazari100_validation_seed4321/vrp_nazari100_validation_seed4321-lkh.pkl",
9+
10+
"test_filepath": "data/vrp/vrp_nazari100_test_seed1234.pkl",
11+
"test_filepath_solution": "results/vrp/vrp_nazari100_test_seed1234/vrp_nazari100_test_seed1234-lkh.pkl",
12+
13+
"num_nodes": 100,
14+
"num_neighbors": -1,
15+
16+
"node_dim": 3,
17+
"voc_nodes_in": 2,
18+
"voc_nodes_out": 2,
19+
"voc_edges_in": 6,
20+
"voc_edges_out": 2,
21+
22+
"beam_size": 1,
23+
24+
"hidden_dim": 300,
25+
"num_layers": 30,
26+
"mlp_layers": 3,
27+
"aggregation": "mean",
28+
29+
"max_epochs": 1500,
30+
"val_every": 5,
31+
"test_every": 100,
32+
33+
"batch_size": 48,
34+
"batches_per_epoch": 500,
35+
"accumulation_steps": 1,
36+
37+
"learning_rate": 0.001,
38+
"decay_rate": 1.01
39+
}

configs/vrp_uchoa100.json

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"expt_name": "vrp_uchoa100",
3+
4+
"train_filepath": "data/vrp/vrp_uchoa100_train_seed42.pkl",
5+
"train_filepath_solution": "results/vrp/vrp_uchoa100_train_seed42/vrp_uchoa100_train_seed42n250000-lkh.pkl",
6+
7+
"val_filepath": "data/vrp/vrp_uchoa100_validation_seed4321.pkl",
8+
"val_filepath_solution": "results/vrp/vrp_uchoa100_validation_seed4321/vrp_uchoa100_validation_seed4321-lkh.pkl",
9+
10+
"test_filepath": "data/vrp/vrp_uchoa100_test_seed1234.pkl",
11+
"test_filepath_solution": "results/vrp/vrp_uchoa100_test_seed1234/vrp_uchoa100_test_seed1234-lkh.pkl",
12+
13+
"num_nodes": 100,
14+
"num_neighbors": -1,
15+
16+
"node_dim": 3,
17+
"voc_nodes_in": 2,
18+
"voc_nodes_out": 2,
19+
"voc_edges_in": 6,
20+
"voc_edges_out": 2,
21+
22+
"beam_size": 1,
23+
24+
"hidden_dim": 300,
25+
"num_layers": 30,
26+
"mlp_layers": 3,
27+
"aggregation": "mean",
28+
29+
"max_epochs": 1500,
30+
"val_every": 5,
31+
"test_every": 100,
32+
33+
"batch_size": 48,
34+
"batches_per_epoch": 500,
35+
"accumulation_steps": 1,
36+
37+
"learning_rate": 0.001,
38+
"decay_rate": 1.01
39+
}

dp/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .graph import Graph, MergedGraph, BatchGraph
2+
from .topk import StreamingTopK, SimpleBatchTopK
3+
from .dp import run_dp

0 commit comments

Comments
 (0)