Skip to content

Commit 924ec4a

Browse files
committed
GAE dev and test versions added. Updated README
1 parent 7e02d32 commit 924ec4a

File tree

8 files changed

+341
-6
lines changed

8 files changed

+341
-6
lines changed

13.GAE.ipynb

+320
Large diffs are not rendered by default.

README.md

+1
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ Relevant Papers:
 10. Rainbow with Quantile Regression [[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/10.Quantile-Rainbow.ipynb)
 11. Deep Recurrent Q-Learning for Partially Observable MDPs [[Publication]](https://arxiv.org/abs/1507.06527)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/11.DRQN.ipynb)
 12. Advantage Actor Critic (A2C) [[Publication1]](https://arxiv.org/abs/1602.01783)[[Publication2]](https://blog.openai.com/baselines-acktr-a2c/)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/12.A2C.ipynb)
+13. High-Dimensional Continuous Control Using Generalized Advantage Estimation [[Publication]](https://arxiv.org/abs/1506.02438)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/13.GAE.ipynb)
 
 
 Requirements:

a2c_devel.py

+2
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,8 @@
 #a2c control
 config.num_agents=16
 config.rollout=5
+config.USE_GAE = True
+config.gae_tau = 0.95
 
 #misc agent variables
 config.GAMMA=0.99

agents/A2C.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def __init__(self, static_policy=False, env=None, config=None):
         self.model.train()
 
         self.rollouts = RolloutStorage(self.rollout, self.num_agents,
-            self.num_feats, self.env.action_space, self.device)
+            self.num_feats, self.env.action_space, self.device, config.USE_GAE, config.gae_tau)
 
         self.value_losses = []
         self.entropy_losses = []

saved_agents/model.dump

0 Bytes
Binary file not shown.

saved_agents/optim.dump

-24 Bytes
Binary file not shown.

utils/RolloutStorage.py

+15-5
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
 import torch
 
 class RolloutStorage(object):
-    def __init__(self, num_steps, num_processes, obs_shape, action_space, device):
+    def __init__(self, num_steps, num_processes, obs_shape, action_space, device, USE_GAE=True, gae_tau=0.95):
         self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape).to(device)
         self.rewards = torch.zeros(num_steps, num_processes, 1).to(device)
         self.value_preds = torch.zeros(num_steps + 1, num_processes, 1).to(device)
@@ -12,6 +12,8 @@ def __init__(self, num_steps, num_processes, obs_shape, action_space, device):
 
         self.num_steps = num_steps
         self.step = 0
+        self.gae = USE_GAE
+        self.gae_tau = gae_tau
 
     def insert(self, current_obs, action, action_log_prob, value_pred, reward, mask):
         self.observations[self.step + 1].copy_(current_obs)
@@ -28,7 +30,15 @@ def after_update(self):
         self.masks[0].copy_(self.masks[-1])
 
     def compute_returns(self, next_value, gamma):
-        self.returns[-1] = next_value
-        for step in reversed(range(self.rewards.size(0))):
-            self.returns[step] = self.returns[step + 1] * \
-                gamma * self.masks[step + 1] + self.rewards[step]
+        if self.gae:
+            self.value_preds[-1] = next_value
+            gae = 0
+            for step in reversed(range(self.rewards.size(0))):
+                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
+                gae = delta + gamma * self.gae_tau * self.masks[step + 1] * gae
+                self.returns[step] = gae + self.value_preds[step]
+        else:
+            self.returns[-1] = next_value
+            for step in reversed(range(self.rewards.size(0))):
+                self.returns[step] = self.returns[step + 1] * \
+                    gamma * self.masks[step + 1] + self.rewards[step]

utils/hyperparameters.py

+2
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@ def __init__(self):
         self.value_loss_weight = 0.5
         self.entropy_loss_weight = 0.001
         self.grad_norm_max = 0.5
+        self.USE_GAE=True
+        self.gae_tau = 0.95
 
         #algorithm control
         self.USE_NOISY_NETS=False

0 commit comments

Comments
 (0)