1
1
from typing import Optional , Iterable
2
- import cv2
3
2
4
3
from dqn_agent import DQNAgent
5
4
from tetris import Tetris
13
12
class AgentConf :
14
13
def __init__ (self ):
15
14
self .n_neurons = [32 , 32 ]
15
+ self .batch_size = 512
16
16
self .activations = ['relu' , 'relu' , 'linear' ]
17
17
self .episodes = 2000
18
- self .epsilon_stop_episode = 1500
18
+ self .epsilon = 1.0
19
+ self .epsilon_min = 0.0
20
+ self .epsilon_stop_episode = 1600
19
21
self .mem_size = 25000
20
22
self .discount = 0.95
21
- self .replay_start_size = 5000
22
- self .batch_size = 1024
23
+ self .replay_start_size = 2000
23
24
self .epochs = 1
24
25
self .render_every = None
25
26
self .train_every = 1
@@ -33,14 +34,19 @@ def dqn(ac: AgentConf):
33
34
34
35
agent = DQNAgent (env .get_state_size (),
35
36
n_neurons = ac .n_neurons , activations = ac .activations ,
36
- epsilon_stop_episode = ac .epsilon_stop_episode , mem_size = ac .mem_size ,
37
- discount = ac .discount , replay_start_size = ac .replay_start_size )
37
+ epsilon = ac .epsilon , epsilon_min = ac .epsilon_min , epsilon_stop_episode = ac . epsilon_stop_episode ,
38
+ mem_size = ac . mem_size , discount = ac .discount , replay_start_size = ac .replay_start_size )
38
39
39
40
timestamp_str = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
40
- log_dir = f'logs/tetris-{ timestamp_str } -nn={ str (ac .n_neurons )} -mem={ ac .mem_size } ' \
41
- f'-bs={ ac .batch_size } -e={ ac .epochs } '
41
+ # conf.mem_size = mem_size
42
+ # conf.epochs = epochs
43
+ # conf.epsilon_stop_episode = epsilon_stop_episode
44
+ # conf.discount = discount
45
+ log_dir = f'logs/tetris-{ timestamp_str } -ms{ ac .mem_size } -e{ ac .epochs } -ese{ ac .epsilon_stop_episode } -d{ ac .discount } '
42
46
log = CustomTensorBoard (log_dir = log_dir )
43
47
48
+ print (f"AGENT_CONF = { log_dir } " )
49
+
44
50
scores = []
45
51
46
52
episodes_wrapped : Iterable [int ] = tqdm (range (ac .episodes ))
@@ -90,12 +96,16 @@ def dqn(ac: AgentConf):
90
96
91
97
92
98
def enumerate_dqn ():
93
- for bs in [256 , 512 , 1024 ]:
94
- for ms in [5000 , 10_000 , 15_000 , 20_000 , 25_000 ]:
95
- agent_conf = AgentConf ()
96
- agent_conf .batch_size = bs
97
- agent_conf .mem_size = ms
98
- dqn (agent_conf )
99
+ for mem_size in [10_000 , 15_000 , 20_000 , 25_000 ]:
100
+ for epochs in [1 , 2 , 3 ]:
101
+ for epsilon_stop_episode in [1600 , 1800 , 2000 ]:
102
+ for discount in [0.95 , 0.97 , 0.99 ]:
103
+ conf = AgentConf ()
104
+ conf .mem_size = mem_size
105
+ conf .epochs = epochs
106
+ conf .epsilon_stop_episode = epsilon_stop_episode
107
+ conf .discount = discount
108
+ dqn (conf )
99
109
100
110
101
111
if __name__ == "__main__" :
0 commit comments