View more on my tutorial website: https://morvanzhou.github.io/tutorials

Dependencies:
- tensorflow r1.3
+ tensorflow 1.8.0
gym 0.9.2
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import gym, threading, queue
- import time

EP_MAX = 1000
EP_LEN = 500
A_LR = 0.0001               # learning rate for actor
C_LR = 0.0001               # learning rate for critic
MIN_BATCH_SIZE = 64         # minimum batch size for updating PPO
- UPDATE_STEP = 10            # loop update operation n-steps
+ UPDATE_STEP = 15            # loop update operation n-steps
EPSILON = 0.2               # for clipping surrogate objective
GAME = 'CartPole-v0'
@@ -51,22 +50,18 @@ def __init__(self):
self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

# actor
- self.pi, self.pi_params = self._build_anet('pi', trainable=True)
+ self.pi, pi_params = self._build_anet('pi', trainable=True)
oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)

- self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(self.pi_params, oldpi_params)]
+ self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]

- self.tfa = tf.placeholder(tf.int32, [None,], 'action')
-
+ self.tfa = tf.placeholder(tf.int32, [None, ], 'action')
self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')

- # debug
- self.val1 = tf.reduce_sum(self.pi * tf.one_hot(self.tfa, A_DIM, dtype=tf.float32), axis=1, keep_dims=True)
- self.val2 = tf.reduce_sum(oldpi * tf.one_hot(self.tfa, A_DIM, dtype=tf.float32), axis=1, keep_dims=True)
- # debug
-
- ratio = self.val1 / self.val2
-
+ a_indices = tf.stack([tf.range(tf.shape(self.tfa)[0], dtype=tf.int32), self.tfa], axis=1)
+ pi_prob = tf.gather_nd(params=self.pi, indices=a_indices)        # shape=(None, )
+ oldpi_prob = tf.gather_nd(params=oldpi, indices=a_indices)       # shape=(None, )
+ ratio = pi_prob / oldpi_prob
surr = ratio * self.tfadv   # surrogate loss

self.aloss = -tf.reduce_mean(tf.minimum(    # clipped surrogate objective
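The new `a_indices` / `tf.gather_nd` lines select each sample's own action probability from the (batch, A_DIM) softmax output, which is exactly what the removed one-hot `reduce_sum` computed. A minimal NumPy sketch of that equivalence (batch of 3, A_DIM = 2, all probability values made up):

```python
import numpy as np

# Illustrative batch: 3 samples, A_DIM = 2 actions (values are made up).
pi_prob_all = np.array([[0.7, 0.3], [0.4, 0.6], [0.9, 0.1]])      # new policy softmax output
oldpi_prob_all = np.array([[0.6, 0.4], [0.5, 0.5], [0.8, 0.2]])   # old policy softmax output
actions = np.array([0, 1, 1])                                     # what self.tfa would hold

# NumPy analogue of tf.stack([tf.range(N), tfa], axis=1) + tf.gather_nd
rows = np.arange(actions.shape[0])
pi_a = pi_prob_all[rows, actions]          # shape (3,), per-sample pi(a|s)
oldpi_a = oldpi_prob_all[rows, actions]    # shape (3,), per-sample oldpi(a|s)
ratio = pi_a / oldpi_a

# The removed one-hot formulation gives the same numbers:
one_hot = np.eye(2)[actions]
old_style = (pi_prob_all * one_hot).sum(axis=1) / (oldpi_prob_all * one_hot).sum(axis=1)
assert np.allclose(ratio, old_style)
```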
@@ -82,20 +77,10 @@ def update(self):
if GLOBAL_EP < EP_MAX:
UPDATE_EVENT.wait()                         # wait until get batch of data
self.sess.run(self.update_oldpi_op)         # copy pi to old pi
- s, a, r = [], [], []
- for iter in range(QUEUE.qsize()):
- data = QUEUE.get()
- if iter == 0:
- s = data['bs']
- a = data['ba']
- r = data['br']
- else:
- s = np.append(s, data['bs'], axis=0)
- a = np.append(a, data['ba'], axis=0)
- r = np.append(r, data['br'], axis=0)
-
+ data = [QUEUE.get() for _ in range(QUEUE.qsize())]      # collect data from all workers
+ data = np.vstack(data)
+ s, a, r = data[:, :S_DIM], data[:, S_DIM: S_DIM + 1].ravel(), data[:, -1:]
adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})
-
# update actor and critic in a update loop
[self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(UPDATE_STEP)]
[self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(UPDATE_STEP)]
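Each worker now pushes one `np.hstack((bs, ba, br))` array per rollout, so `update()` only has to `np.vstack` the queue contents and slice columns. A small NumPy sketch of that packing and unpacking, assuming CartPole's S_DIM = 4 (all values made up):

```python
import numpy as np

S_DIM = 4                            # CartPole-v0 observation size
bs = np.random.rand(3, S_DIM)        # a worker's batch of states
ba = np.array([[0], [1], [0]])       # actions as a column, i.e. np.vstack(buffer_a)
br = np.random.rand(3, 1)            # discounted returns as a column

packed = np.hstack((bs, ba, br))     # what QUEUE.put() receives, shape (3, S_DIM + 2)

# What update() recovers after np.vstack over all queued arrays:
s = packed[:, :S_DIM]                      # (N, S_DIM) states for self.tfs
a = packed[:, S_DIM: S_DIM + 1].ravel()    # (N,) actions for the tf.int32 placeholder
r = packed[:, -1:]                         # (N, 1) discounted returns for self.tfdc_r

assert s.shape == (3, S_DIM) and a.shape == (3,) and r.shape == (3, 1)
```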
@@ -104,16 +89,14 @@ def update(self):
ROLLING_EVENT.set()         # set roll-out available

def _build_anet(self, name, trainable):
- w_init = tf.random_normal_initializer(0., .1)
-
with tf.variable_scope(name):
l_a = tf.layers.dense(self.tfs, 200, tf.nn.relu, trainable=trainable)
- a_prob = tf.layers.dense(l_a, A_DIM, tf.nn.softmax, trainable=trainable, name='ap')
+ a_prob = tf.layers.dense(l_a, A_DIM, tf.nn.softmax, trainable=trainable)
params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
return a_prob, params

def choose_action(self, s):  # run by a local
- prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[np.newaxis, :]})
+ prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[None, :]})
action = np.random.choice(range(prob_weights.shape[1]),
p=prob_weights.ravel())     # select action w.r.t the actions prob
return action
@@ -132,28 +115,26 @@ def __init__(self, wid):
def work(self):
global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
while not COORD.should_stop():
- s = self.env.reset()    # new episode
+ s = self.env.reset()
ep_r = 0
buffer_s, buffer_a, buffer_r = [], [], []
for t in range(EP_LEN):
if not ROLLING_EVENT.is_set():                  # while global PPO is updating
ROLLING_EVENT.wait()                        # wait until PPO is updated
buffer_s, buffer_a, buffer_r = [], [], []   # clear history buffer, use new policy to collect data
-
a = self.ppo.choose_action(s)
s_, r, done, _ = self.env.step(a)
- if done: r = -5
+ if done: r = -10
buffer_s.append(s)
buffer_a.append(a)
- buffer_r.append((r + 8) / 8)    # normalize reward, find to be useful
+ buffer_r.append(r - 1)          # 0 for not down, -11 for down. Reward engineering
s = s_
ep_r += r

- GLOBAL_UPDATE_COUNTER += 1 # count to minimum batch size, no need to wait other workers
+ GLOBAL_UPDATE_COUNTER += 1      # count to minimum batch size, no need to wait other workers
if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done:
-
if done:
- v_s_ = 0    # episode ends
+ v_s_ = 0    # end of episode
else:
v_s_ = self.ppo.get_v(s_)
@@ -162,33 +143,25 @@ def work(self):
v_s_ = r + GAMMA * v_s_
discounted_r.append(v_s_)
discounted_r.reverse()
-
- bs, ba, br = np.vstack(buffer_s), np.array(buffer_a), np.array(discounted_r)[:, np.newaxis]
-
- buffer_s, buffer_a, buffer_r = [], [], []
-
- q_in = dict([('bs', bs), ('ba', ba), ('br', br)])
- # q_in = dict([('bs', list(bs)), ('ba', list(ba)), ('br', list(br))])
-
- QUEUE.put(q_in)

+ bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, None]
+ buffer_s, buffer_a, buffer_r = [], [], []
+ QUEUE.put(np.hstack((bs, ba, br)))          # put data in the queue
if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
ROLLING_EVENT.clear()       # stop collecting data
UPDATE_EVENT.set()          # globalPPO update
-
+
if GLOBAL_EP >= EP_MAX:     # stop training
COORD.request_stop()
break

- if done:break
+ if done: break

# record reward changes, plot later
if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
else: GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1]*0.9 + ep_r*0.1)
GLOBAL_EP += 1
- print("EP", GLOBAL_EP, '|W%i' % self.wid, '|step %i' % t, '|Ep_r: %.2f' % ep_r,)
- np.save("Global_return", GLOBAL_RUNNING_R)
- np.savez("PI_PARA", self.ppo.sess.run(GLOBAL_PPO.pi_params))
+ print('{0:.1f}%'.format(GLOBAL_EP / EP_MAX * 100), '|W%i' % self.wid, '|Ep_r: %.2f' % ep_r,)


if __name__ == '__main__':
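For reference, the backward loop over `buffer_r` in the hunk above (`v_s_ = r + GAMMA * v_s_`, then `reverse()`) computes bootstrapped discounted returns. A standalone sketch with made-up rewards, GAMMA = 0.9, and the terminal case `v_s_ = 0`:

```python
GAMMA = 0.9
buffer_r = [0.0, 0.0, -11.0]   # two surviving steps (r - 1 = 0), then the terminal penalty
v_s_ = 0.0                     # end of episode; otherwise this would be the critic's V(s_)

discounted_r = []
for r in buffer_r[::-1]:       # same backward recursion as in work()
    v_s_ = r + GAMMA * v_s_
    discounted_r.append(v_s_)
discounted_r.reverse()

print(discounted_r)            # -> [-8.91, -9.9, -11.0] up to float rounding
```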
@@ -197,9 +170,7 @@ def work(self):
UPDATE_EVENT.clear()        # not update now
ROLLING_EVENT.set()         # start to roll out
workers = [Worker(wid=i) for i in range(N_WORKER)]
-
- start = time.time()
-
+
GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
GLOBAL_RUNNING_R = []
COORD = tf.train.Coordinator()
@@ -214,9 +185,6 @@ def work(self):
threads[-1].start()
COORD.join(threads)

- end = time.time()
- print "Total time ", (end - start)
-
# plot reward change and test
plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
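The training loop above coordinates workers and the global updater purely through the two `threading.Event` flags: workers roll out while ROLLING_EVENT is set, and the chief update loop sleeps on UPDATE_EVENT until a worker signals that MIN_BATCH_SIZE samples are queued. A stripped-down sketch of that handshake with no TensorFlow, made-up sizes, and a plain stop flag standing in for tf.train.Coordinator:

```python
import threading, queue

UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
UPDATE_EVENT.clear()                      # updater starts asleep
ROLLING_EVENT.set()                       # workers start collecting
QUEUE = queue.Queue()
MIN_BATCH_SIZE, N_UPDATES = 3, 4          # made-up sizes for the demo
STOP = threading.Event()                  # stands in for tf.train.Coordinator

def worker(wid):
    step = 0
    while not STOP.is_set():
        ROLLING_EVENT.wait(timeout=0.1)   # pause while the updater is working
        if not ROLLING_EVENT.is_set():
            continue
        QUEUE.put((wid, step))            # pretend this is a rollout batch
        step += 1
        if QUEUE.qsize() >= MIN_BATCH_SIZE:
            ROLLING_EVENT.clear()         # stop collecting data
            UPDATE_EVENT.set()            # wake the updater

def updater():
    for i in range(N_UPDATES):
        UPDATE_EVENT.wait()                                   # wait until a batch is ready
        batch = [QUEUE.get() for _ in range(QUEUE.qsize())]
        print('update %i on %i samples' % (i, len(batch)))
        UPDATE_EVENT.clear()
        ROLLING_EVENT.set()                                   # let workers roll out again
    STOP.set()
    ROLLING_EVENT.set()                                       # release any waiting worker

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
threads.append(threading.Thread(target=updater))
[t.start() for t in threads]
[t.join() for t in threads]
```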