Commit 601990e
Merge pull request #69 from ruihuili/master
DPPO with discrete action space
2 parents 7334763 + b66fb13

1 file changed: 231 additions, 0 deletions
"""
A simple version of OpenAI's Proximal Policy Optimization (PPO). [https://arxiv.org/abs/1707.06347]

Distribute workers in parallel to collect data, then stop the workers' roll-outs and train PPO on the collected data.
Restart the workers once PPO has been updated.

The global PPO updating rule is adopted from DeepMind's paper (DPPO):
Emergence of Locomotion Behaviours in Rich Environments (Google DeepMind): [https://arxiv.org/abs/1707.02286]

View more on my tutorial website: https://morvanzhou.github.io/tutorials

Dependencies:
tensorflow r1.3
gym 0.9.2
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import gym, threading, queue
import time

EP_MAX = 1000               # maximum number of training episodes
EP_LEN = 500                # maximum steps per episode
N_WORKER = 4                # parallel workers
GAMMA = 0.9                 # reward discount factor
A_LR = 0.0001               # learning rate for actor
C_LR = 0.0001               # learning rate for critic
MIN_BATCH_SIZE = 64         # minimum batch size for updating PPO
UPDATE_STEP = 10            # loop update operation n-steps
EPSILON = 0.2               # for clipping the surrogate objective
GAME = 'CartPole-v0'

env = gym.make(GAME)
S_DIM = env.observation_space.shape[0]      # state dimension
A_DIM = env.action_space.n                  # number of discrete actions

class PPONet(object):
    def __init__(self):
        self.sess = tf.Session()
        self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')

        # critic
        w_init = tf.random_normal_initializer(0., .1)
        lc = tf.layers.dense(self.tfs, 200, tf.nn.relu, kernel_initializer=w_init, name='lc')
        self.v = tf.layers.dense(lc, 1)                                 # state value V(s)
        self.tfdc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r')
        self.advantage = self.tfdc_r - self.v
        self.closs = tf.reduce_mean(tf.square(self.advantage))
        self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

        # actor
        self.pi, self.pi_params = self._build_anet('pi', trainable=True)
        oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)

        self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(self.pi_params, oldpi_params)]

        self.tfa = tf.placeholder(tf.int32, [None, ], 'action')

        self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')

        # probability of the taken action under the current and the old policy
        self.val1 = tf.reduce_sum(self.pi * tf.one_hot(self.tfa, A_DIM, dtype=tf.float32), axis=1, keep_dims=True)
        self.val2 = tf.reduce_sum(oldpi * tf.one_hot(self.tfa, A_DIM, dtype=tf.float32), axis=1, keep_dims=True)

        ratio = self.val1 / self.val2

        surr = ratio * self.tfadv                       # surrogate loss

        self.aloss = -tf.reduce_mean(tf.minimum(        # clipped surrogate objective
            surr,
            tf.clip_by_value(ratio, 1. - EPSILON, 1. + EPSILON) * self.tfadv))

        self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)
        self.sess.run(tf.global_variables_initializer())

    def update(self):
        global GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():
            if GLOBAL_EP < EP_MAX:
                UPDATE_EVENT.wait()                     # wait until a batch of data has been collected
                self.sess.run(self.update_oldpi_op)     # copy pi to oldpi
                s, a, r = [], [], []
                for i in range(QUEUE.qsize()):          # collect the data from all workers
                    data = QUEUE.get()
                    if i == 0:
                        s = data['bs']
                        a = data['ba']
                        r = data['br']
                    else:
                        s = np.append(s, data['bs'], axis=0)
                        a = np.append(a, data['ba'], axis=0)
                        r = np.append(r, data['br'], axis=0)

                adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})

                # update actor and critic in an update loop
                [self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(UPDATE_STEP)]
                [self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(UPDATE_STEP)]
                UPDATE_EVENT.clear()        # updating finished
                GLOBAL_UPDATE_COUNTER = 0   # reset counter
                ROLLING_EVENT.set()         # set roll-out available again

    def _build_anet(self, name, trainable):
        w_init = tf.random_normal_initializer(0., .1)

        with tf.variable_scope(name):
            l_a = tf.layers.dense(self.tfs, 200, tf.nn.relu, trainable=trainable)
            a_prob = tf.layers.dense(l_a, A_DIM, tf.nn.softmax, trainable=trainable, name='ap')
        params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
        return a_prob, params

    def choose_action(self, s):                 # run by a local worker
        prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]),
                                  p=prob_weights.ravel())   # sample an action w.r.t. the action probabilities
        return action

    def get_v(self, s):
        if s.ndim < 2: s = s[np.newaxis, :]
        return self.sess.run(self.v, {self.tfs: s})[0, 0]

class Worker(object):
    def __init__(self, wid):
        self.wid = wid
        self.env = gym.make(GAME).unwrapped
        self.ppo = GLOBAL_PPO

    def work(self):
        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():
            s = self.env.reset()                # start a new episode
            ep_r = 0
            buffer_s, buffer_a, buffer_r = [], [], []
            for t in range(EP_LEN):
                if not ROLLING_EVENT.is_set():                  # while the global PPO is updating
                    ROLLING_EVENT.wait()                        # wait until PPO is updated
                    buffer_s, buffer_a, buffer_r = [], [], []   # clear the history buffer; collect data with the new policy

                a = self.ppo.choose_action(s)
                s_, r, done, _ = self.env.step(a)
                if done: r = -5                 # penalty when the episode ends early
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append((r + 8) / 8)    # rescale reward, found to be useful
                s = s_
                ep_r += r

                GLOBAL_UPDATE_COUNTER += 1      # count towards the minimum batch size; no need to wait for other workers
                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done:

                    if done:
                        v_s_ = 0                                # episode ends, terminal value is zero
                    else:
                        v_s_ = self.ppo.get_v(s_)

                    discounted_r = []                           # compute discounted reward
                    for r in buffer_r[::-1]:
                        v_s_ = r + GAMMA * v_s_
                        discounted_r.append(v_s_)
                    discounted_r.reverse()

                    bs, ba, br = np.vstack(buffer_s), np.array(buffer_a), np.array(discounted_r)[:, np.newaxis]

                    buffer_s, buffer_a, buffer_r = [], [], []

                    q_in = dict([('bs', bs), ('ba', ba), ('br', br)])
                    # q_in = dict([('bs', list(bs)), ('ba', list(ba)), ('br', list(br))])

                    QUEUE.put(q_in)             # put the batch in the queue for the updater

                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                        ROLLING_EVENT.clear()       # stop collecting data
                        UPDATE_EVENT.set()          # trigger the global PPO update

                    if GLOBAL_EP >= EP_MAX:         # stop training
                        COORD.request_stop()
                        break

                if done: break

            # record reward changes, plot later
            if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
            else: GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1] * 0.9 + ep_r * 0.1)
            GLOBAL_EP += 1
            print("EP", GLOBAL_EP, '|W%i' % self.wid, '|step %i' % t, '|Ep_r: %.2f' % ep_r)
            np.save("Global_return", GLOBAL_RUNNING_R)
            np.savez("PI_PARA", self.ppo.sess.run(GLOBAL_PPO.pi_params))

if __name__ == '__main__':
    GLOBAL_PPO = PPONet()
    UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
    UPDATE_EVENT.clear()            # not updating now
    ROLLING_EVENT.set()             # start to roll out
    workers = [Worker(wid=i) for i in range(N_WORKER)]

    start = time.time()

    GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
    GLOBAL_RUNNING_R = []
    COORD = tf.train.Coordinator()
    QUEUE = queue.Queue()           # workers put their data in this queue
    threads = []
    for worker in workers:          # worker threads
        t = threading.Thread(target=worker.work, args=())
        t.start()                   # start training
        threads.append(t)
    # add a PPO updating thread
    threads.append(threading.Thread(target=GLOBAL_PPO.update,))
    threads[-1].start()
    COORD.join(threads)

    end = time.time()
    print("Total time ", end - start)

    # plot reward change and test
    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
    env = gym.make('CartPole-v0')
    while True:
        s = env.reset()
        for t in range(1000):
            env.render()
            s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
            if done:
                break
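The script stores results with np.save("Global_return", ...) and np.savez("PI_PARA", ...), which NumPy writes as Global_return.npy and PI_PARA.npz. A minimal sketch (not part of this commit) for reloading and plotting the saved return curve afterwards:

import numpy as np
import matplotlib.pyplot as plt

returns = np.load("Global_return.npy")      # moving-average episode returns saved during training
plt.plot(np.arange(len(returns)), returns)
plt.xlabel('Episode'); plt.ylabel('Moving reward')
plt.show()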
