Skip to content

Commit 982e544

Browse files
Create DuelingDQN.py
Dueling DQN
1 parent 2ca895a commit 982e544

File tree

1 file changed

+279
-0
lines changed

1 file changed

+279
-0
lines changed

DQN/DuelingDQN.py

+279
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import gym
2+
import os
3+
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
4+
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
5+
import tensorflow as tf
6+
import numpy as np
7+
import random
8+
from collections import deque
9+
from environment import Environment
10+
from tensorboardX import SummaryWriter
11+
12+
# Hyper Parameters for DQN
13+
GAMMA = 0.9 # discount factor for target Q
14+
INITIAL_EPSILON = 0.5 # starting value of epsilon
15+
FINAL_EPSILON = 0.01 # final value of epsilon
16+
REPLAY_SIZE = 50000 # experience replay buffer size
17+
BATCH_SIZE = 32 # size of minibatch
18+
REPLACE_TARGET_FREQ = 10 # frequency to update target Q network
19+
DOUBLE_DQN = False
20+
DUELING_DQN = True
21+
22+
23+
class DQN():
24+
# DQN Agent
25+
def __init__(self, env):
26+
# init experience replay
27+
self.replay_buffer = deque()
28+
# init some parameters
29+
self.time_step = 0
30+
self.epsilon = INITIAL_EPSILON
31+
self.state_dim = env.observation_space.shape[0]
32+
self.action_dim = env.action_space.n
33+
34+
self.create_Q_network()
35+
self.create_training_method()
36+
37+
# Init session
38+
self.session = tf.InteractiveSession()
39+
self.session.run(tf.global_variables_initializer())
40+
41+
def create_Q_network(self):
42+
"""
43+
Q net 网络定义
44+
:return:
45+
"""
46+
# 输入状态 placeholder
47+
self.state_input = tf.placeholder("float", [None, self.state_dim])
48+
49+
# Q 网络结构 两层全连接
50+
with tf.variable_scope('current_net'):
51+
W1 = self.weight_variable([self.state_dim, 100])
52+
b1 = self.bias_variable([100])
53+
h_layer = tf.nn.tanh(tf.matmul(self.state_input, W1) + b1)
54+
55+
if DUELING_DQN:
56+
with tf.variable_scope('current_net_value'):
57+
W2 = self.weight_variable([100, 1])
58+
b2 = self.bias_variable([1])
59+
self.V = tf.matmul(h_layer, W2) + b2
60+
61+
with tf.variable_scope('current_net_advantage'):
62+
W2 = self.weight_variable([100, self.action_dim])
63+
b2 = self.bias_variable([self.action_dim])
64+
self.A = tf.matmul(h_layer, W2) + b2
65+
66+
with tf.variable_scope('Q'):
67+
# Q Value # 合并 V 和 A, 为了不让 A 直接学成了 Q, 我们减掉了 A 的均值
68+
self.Q_value = self.V + (self.A - tf.reduce_mean(self.A, axis=1, keep_dims=True)) # Q = V(s) + A(s,a)
69+
else:
70+
with tf.variable_scope('Q'):
71+
W2 = self.weight_variable([100, self.action_dim])
72+
b2 = self.bias_variable([self.action_dim])
73+
h_layer = tf.nn.tanh(tf.matmul(self.state_input, W1) + b1)
74+
# Q Value
75+
self.Q_value = tf.matmul(h_layer, W2) + b2
76+
77+
# Target Net 结构与 Q相同,可以用tf的reuse实现
78+
with tf.variable_scope('target_net'):
79+
W1t = self.weight_variable([self.state_dim, 100])
80+
b1t = self.bias_variable([100])
81+
h_layer = tf.nn.tanh(tf.matmul(self.state_input, W1t) + b1t)
82+
83+
if DUELING_DQN:
84+
with tf.variable_scope('target_net_value'):
85+
W2t = self.weight_variable([100, 1])
86+
b2t = self.bias_variable([1])
87+
self.Vt = tf.matmul(h_layer, W2t) + b2t
88+
89+
with tf.variable_scope('target_net_advantage'):
90+
W2t = self.weight_variable([100, self.action_dim])
91+
b2t = self.bias_variable([self.action_dim])
92+
self.At = tf.matmul(h_layer, W2t) + b2t
93+
94+
with tf.variable_scope('target_Q'):
95+
# Q Value
96+
self.target_Q_value = self.Vt + (
97+
self.At - tf.reduce_mean(self.At, axis=1, keep_dims=True)) # Q = V(s) + A(s,a)
98+
else:
99+
with tf.variable_scope('target_Q'):
100+
W2t = self.weight_variable([100, self.action_dim])
101+
b2t = self.bias_variable([self.action_dim])
102+
h_layer_t = tf.nn.tanh(tf.matmul(self.state_input, W1t) + b1t)
103+
# target Q Value
104+
self.target_Q_value = tf.matmul(h_layer_t, W2t) + b2t
105+
106+
t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
107+
e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='current_net')
108+
109+
# soft update 更新 target net
110+
with tf.variable_scope('soft_replacement'):
111+
self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
112+
113+
def weight_variable(self, shape):
114+
"""
115+
初始化网络权值(随机, truncated_normal)
116+
:param shape:
117+
:return:
118+
"""
119+
initial = tf.truncated_normal(shape)
120+
return tf.Variable(initial)
121+
122+
def bias_variable(self, shape):
123+
"""
124+
初始化bias(const)
125+
:param shape:
126+
:return:
127+
"""
128+
initial = tf.constant(0.01, shape=shape)
129+
return tf.Variable(initial)
130+
131+
def create_training_method(self):
132+
self.action_input = tf.placeholder("float", [None, self.action_dim]) # one hot presentation
133+
self.y_input = tf.placeholder("float", [None])
134+
Q_action = tf.reduce_sum(tf.multiply(self.Q_value, self.action_input), reduction_indices=1)
135+
self.cost = tf.reduce_mean(tf.square(self.y_input - Q_action))
136+
self.optimizer = tf.train.AdamOptimizer(0.0001).minimize(self.cost)
137+
138+
def perceive(self, state, action, reward, next_state, done):
139+
"""
140+
Replay buffer
141+
:param state:
142+
:param action:
143+
:param reward:
144+
:param next_state:
145+
:param done:
146+
:return:
147+
"""
148+
# 对action 进行one-hot存储,方便网络进行处理
149+
# [0,0,0,0,1,0,0,0,0] action=5
150+
one_hot_action = np.zeros(self.action_dim)
151+
one_hot_action[action] = 1
152+
153+
# 存入replay_buffer
154+
# self.replay_buffer = deque()
155+
self.replay_buffer.append((state, one_hot_action, reward, next_state, done))
156+
157+
# 溢出出队
158+
if len(self.replay_buffer) > REPLAY_SIZE:
159+
self.replay_buffer.popleft()
160+
161+
# 可进行训练条件
162+
if len(self.replay_buffer) > BATCH_SIZE:
163+
self.train_Q_network()
164+
165+
def train_Q_network(self):
166+
"""
167+
Q网络训练
168+
:return:
169+
"""
170+
self.time_step += 1
171+
# 从 replay buffer D中随机选取 batch size N条数据<s_j,a_j,r_j,s_j+1,done>$ D_selected
172+
minibatch = random.sample(self.replay_buffer, BATCH_SIZE)
173+
state_batch = [data[0] for data in minibatch]
174+
action_batch = [data[1] for data in minibatch]
175+
reward_batch = [data[2] for data in minibatch]
176+
next_state_batch = [data[3] for data in minibatch]
177+
178+
# 计算目标Q值y
179+
y_batch = []
180+
QTarget_value_batch = self.target_Q_value.eval(feed_dict={self.state_input: next_state_batch})
181+
Q_value_batch = self.Q_value.eval(feed_dict={self.state_input: next_state_batch})
182+
for i in range(0, BATCH_SIZE):
183+
done = minibatch[i][4]
184+
if done:
185+
y_batch.append(reward_batch[i])
186+
else:
187+
#################用target Q(Q)#######################
188+
if DOUBLE_DQN:
189+
selected_q_next = QTarget_value_batch[i][np.argmax(Q_value_batch[i])]
190+
#################用target Q(target Q)################
191+
else:
192+
selected_q_next = np.max(QTarget_value_batch[i])
193+
194+
y_batch.append(reward_batch[i] + GAMMA * selected_q_next)
195+
196+
self.optimizer.run(feed_dict={
197+
self.y_input: y_batch,
198+
self.action_input: action_batch,
199+
self.state_input: state_batch
200+
})
201+
202+
def egreedy_action(self, state):
203+
"""
204+
epsilon-greedy策略
205+
:param state:
206+
:return:
207+
"""
208+
Q_value = self.Q_value.eval(feed_dict={
209+
self.state_input: [state]
210+
})[0]
211+
if random.random() <= self.epsilon:
212+
self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
213+
return random.randint(0, self.action_dim - 1)
214+
else:
215+
self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
216+
return np.argmax(Q_value)
217+
218+
def action(self, state):
219+
return np.argmax(self.Q_value.eval(feed_dict={
220+
self.state_input: [state]
221+
})[0])
222+
223+
def update_target_q_network(self, episode):
224+
# update target Q netowrk
225+
if episode % REPLACE_TARGET_FREQ == 0:
226+
self.session.run(self.target_replace_op)
227+
# print('episode '+str(episode) +', target Q network params replaced!')
228+
229+
230+
# ---------------------------------------------------------
231+
# Hyper Parameters
232+
ENV_NAME = 'CartPole-v0'
233+
EPISODE = 1000 # Episode limitation
234+
235+
236+
def main():
237+
# initialize OpenAI Gym env and dqn agent
238+
env = Environment()
239+
env = gym.make(ENV_NAME)
240+
241+
agent = DQN(env)
242+
writer = SummaryWriter()
243+
244+
# with writer:
245+
# writer.add_graph(net, (input_data,))
246+
score = []
247+
mean = []
248+
for episode in range(EPISODE):
249+
# initialize task
250+
state = env.reset()
251+
total_reward = 0
252+
253+
step = 0
254+
# Train
255+
# for step in range(STEP):
256+
while True:
257+
action = agent.egreedy_action(state) # e-greedy action for train
258+
next_state, reward, done, _ = env.step(action)
259+
260+
agent.perceive(state, action, reward, next_state, done)
261+
state = next_state
262+
if done:
263+
total_reward = total_reward
264+
total_reward = total_reward + reward
265+
break
266+
total_reward += reward
267+
step += 1
268+
269+
print(total_reward)
270+
score.append(total_reward)
271+
writer.add_scalar('total_reward', total_reward, episode)
272+
mean_reward = sum(score[-100:]) / 100
273+
mean.append(mean_reward)
274+
writer.add_scalar('mean_reward', mean_reward, episode)
275+
agent.update_target_q_network(episode)
276+
277+
278+
if __name__ == '__main__':
279+
main()

0 commit comments

Comments
 (0)