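"""train.py: train a DQN agent on ALE/MsPacman-v5.

Usage:
    python train.py <config_path>

The config file is merged over the defaults from src.config.config, and one
training run is launched for every seed listed in TRAINING.SEEDS.
"""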
import os
import tensorflow as tf
from tensorflow import keras
from argparse import ArgumentParser
import gym
from gym.wrappers import FrameStack, AtariPreprocessing
from src.config import config
from src.agents.dqn import DQNAgent
from src.callbacks.wandb import WandbTrainingCallback
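
# Note: this script uses the classic `gym` API (gym.wrappers.FrameStack and
# gym.wrappers.AtariPreprocessing) with the ALE-namespaced "ALE/MsPacman-v5"
# environment, not the newer gymnasium package.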


def run_training(cfg):
    # Build the Atari environment: sticky actions come from ALE via
    # repeat_action_probability, while frame skipping, grayscaling and
    # resizing are handled by the AtariPreprocessing wrapper; the last
    # four frames are stacked to form the observation.
    env_orig = gym.make(
        "ALE/MsPacman-v5",
        frameskip=1,
        render_mode=cfg.ENV.RENDER_MODE,
        repeat_action_probability=cfg.ENV.STICKY_ACTION_PROB)
    env = AtariPreprocessing(env_orig, terminal_on_life_loss=cfg.ENV.TERMINAL_ON_LIFE_LOSS)
    env = FrameStack(env, num_stack=4)

    # Stacked observations are channels-first, so Keras must match.
    tf.keras.backend.set_image_data_format('channels_first')

    # DQN agent; the target network, double / dueling / noisy-nets /
    # prioritized-replay variants, exploration schedule and replay settings
    # are all driven by the config.
    agent = DQNAgent(
        name=cfg.TRAINING.WANDB.GROUP,
        use_target=cfg.AGENT.USE_TARGET,
        update_target_after=cfg.AGENT.UPDATE_TAGRET_AFTER,
        tau=cfg.AGENT.TAU,
        double=cfg.AGENT.DOUBLE,
        prioritized_replay=cfg.AGENT.PRIORITIZED_REPLAY,
        dueling=cfg.AGENT.DUELING,
        noisy_nets=cfg.AGENT.NOISY_NETS,
        beta=cfg.AGENT.BETA,
        augment=cfg.AGENT.AUGMENT,
        discount_factor=cfg.AGENT.DISCOUNT_FACTOR,
        clip_rewards=cfg.AGENT.CLIP_REWARDS,
        record_video=cfg.AGENT.RECORD_VIDEO,
        log_table=cfg.AGENT.LOG_TABLE,
        log_table_period=cfg.AGENT.LOG_TABLE_PERIOD,
        evaluate_after=cfg.AGENT.EVALUATE_AFTER,
        evaluation_episodes=cfg.AGENT.EVALUATION_EPISODES,
        eps_start=cfg.AGENT.EPS_START,
        eps_end=cfg.AGENT.EPS_END,
        eps_eval=cfg.AGENT.EPS_EVAL,
        annealing_steps=cfg.AGENT.EPS_ANNEALING_STEPS,
        buffer_size=cfg.AGENT.BUFFER_SIZE,
        max_train_frames=cfg.AGENT.MAX_TRAIN_FRAMES,
        max_episode_frames=cfg.AGENT.MAX_EPISODE_FRAMES,
        warmup_frames=cfg.AGENT.WARMUP_FRAMES,
        batch_size=cfg.AGENT.BATCH_SIZE,
        num_updates_per_step=cfg.AGENT.NUM_UPDATES_PER_STEP)
    agent.online_network.summary()

    # Huber or MSE loss; Reduction.NONE keeps per-sample losses so the agent
    # can apply its own weighting (e.g. importance weights for prioritized replay).
    loss_fn = keras.losses.Huber if cfg.TRAINING.LOSS == 'huber' else keras.losses.MeanSquaredError
    agent.compile(
        environment=env,
        optimizer=keras.optimizers.Adam(
            learning_rate=cfg.TRAINING.LEARNING_RATE,
            clipnorm=cfg.TRAINING.CLIPNORM),
        loss=loss_fn(reduction=tf.keras.losses.Reduction.NONE))

    # Log the config to W&B, minus the W&B block itself and the evaluation settings.
    wandb_config = cfg.clone()
    del wandb_config['TRAINING']['WANDB']
    del wandb_config['EVALUATION']
    wandb_callback = WandbTrainingCallback(
        cfg.TRAINING.WANDB.PROJECT,
        cfg.TRAINING.WANDB.NAME,
        cfg.TRAINING.WANDB.GROUP,
        config=wandb_config)

    agent.fit(callbacks=[wandb_callback], seed=cfg.TRAINING.SEED)
    agent.save_checkpoint(cfg.TRAINING.CHECKPOINT_PATH)

    env.close()
    env_orig.close()


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('config_path', help='Path to the agent\'s configuration file')
    args = parser.parse_args()

    cfg = config.get_cfg_defaults()
    cfg.merge_from_file(args.config_path)
    cfg.freeze()

    # One training run per seed: append the seed to the run name and to the
    # checkpoint/evaluation paths so runs do not overwrite each other.
    # Per-seed values are written via item access on a clone; cfg itself stays frozen.
    for seed in cfg.TRAINING.SEEDS:
        new_cfg = cfg.clone()
        old_name = cfg.TRAINING.CHECKPOINT_PATH.split('/')[1]
        new_name = old_name + f'-{seed}'
        seeded_path = cfg.TRAINING.CHECKPOINT_PATH.replace(old_name, new_name)
        new_cfg['TRAINING']['SEED'] = seed
        new_cfg['TRAINING']['CHECKPOINT_PATH'] = seeded_path
        new_cfg['EVALUATION']['EVALUATION_PATH'] = seeded_path
        new_cfg['TRAINING']['WANDB']['NAME'] += str(seed)
        run_training(new_cfg)
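
# A minimal sketch of running a single seed programmatically (e.g. from a
# notebook) rather than through the CLI loop above; the config path is a
# hypothetical placeholder, not a file shipped with this repo:
#
#     from src.config import config
#     cfg = config.get_cfg_defaults()
#     cfg.merge_from_file('configs/mspacman_dqn.yaml')  # hypothetical path
#     cfg['TRAINING']['SEED'] = 0
#     run_training(cfg)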