@Date: 2020-06-11 20:58:21
@LastEditor: John
- LastEditTime: 2021-05-04 14:49:45
+ LastEditTime: 2021-09-16 01:31:33
@Description:
@Environment: python 3.7.7
'''
import sys, os
- curr_path = os.path.dirname(__file__)
- parent_path = os.path.dirname(curr_path)
- sys.path.append(parent_path)  # add current terminal path to sys.path
+ curr_path = os.path.dirname(os.path.abspath(__file__))  # absolute path of the current file
+ parent_path = os.path.dirname(curr_path)  # parent directory
+ sys.path.append(parent_path)  # add the parent directory to sys.path

import datetime
import gym
from DDPG.env import NormalizedActions, OUNoise
from DDPG.agent import DDPG
from common.utils import save_results, make_dir
- from common.plot import plot_rewards
-
- curr_time = datetime.datetime.now().strftime(
-     "%Y%m%d-%H%M%S")  # obtain current time
+ from common.plot import plot_rewards, plot_rewards_cn

+ curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time

class DDPGConfig:
    def __init__(self):
-         self.algo = 'DDPG'
-         self.env = 'Pendulum-v0'  # env name
+         self.algo = 'DDPG'  # algorithm name
+         self.env = 'Pendulum-v0'  # environment name
        self.result_path = curr_path + "/outputs/" + self.env + \
-             '/' + curr_time + '/results/'  # path to save results
+             '/' + curr_time + '/results/'  # path to save results
        self.model_path = curr_path + "/outputs/" + self.env + \
-             '/' + curr_time + '/models/'  # path to save results
-         self.gamma = 0.99
-         self.critic_lr = 1e-3
-         self.actor_lr = 1e-4
-         self.memory_capacity = 10000
+             '/' + curr_time + '/models/'  # path to save models
+         self.train_eps = 300  # number of training episodes
+         self.eval_eps = 50  # number of evaluation episodes
+         self.gamma = 0.99  # discount factor
+         self.critic_lr = 1e-3  # critic learning rate
+         self.actor_lr = 1e-4  # actor learning rate
+         self.memory_capacity = 8000  # replay buffer capacity
        self.batch_size = 128
-         self.train_eps = 300
-         self.eval_eps = 50
-         self.eval_steps = 200
-         self.target_update = 4
-         self.hidden_dim = 30
-         self.soft_tau = 1e-2
-         self.device = torch.device(
-             "cuda" if torch.cuda.is_available() else "cpu")
+         self.target_update = 2
+         self.hidden_dim = 256  # hidden layer dimension
+         self.soft_tau = 1e-2  # soft update coefficient
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

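# A rough sketch (not code from this repo) of how the hyperparameters above are
# typically consumed inside DDPG.agent.DDPG.update():
#   critic target: y = r + gamma * (1 - done) * Q_target(s', actor_target(s'))
#   critic loss:   MSE(Q(s, a), y), minimized with critic_lr
#   actor loss:    -mean(Q(s, actor(s))), minimized with actor_lr
#   soft update:   theta_target <- soft_tau * theta + (1 - soft_tau) * theta_target
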
def env_agent_config(cfg, seed=1):
    env = NormalizedActions(gym.make(cfg.env))
-     env.seed(seed)
+     env.seed(seed)  # set random seed
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    agent = DDPG(state_dim, action_dim, cfg)
    return env, agent
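# env_agent_config builds a seeded env plus a DDPG agent sized to its observation/action
# spaces; NormalizedActions (from DDPG/env.py) is assumed here to rescale the agent's
# actions to the environment's bounds, though that wrapper's code is not shown in this diff.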

def train(cfg, env, agent):
-     print('Start to train!')
-     print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')
-     ou_noise = OUNoise(env.action_space)  # action noise
-     rewards = []
-     ma_rewards = []  # moving average rewards
-     for i_episode in range(cfg.train_eps):
+     print('Start training!')
+     print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')
+     ou_noise = OUNoise(env.action_space)  # action noise
+     rewards = []  # record rewards
+     ma_rewards = []  # record moving-average rewards
+     for i_ep in range(cfg.train_eps):
        state = env.reset()
        ou_noise.reset()
        done = False
@@ -72,29 +68,29 @@ def train(cfg, env, agent):
        while not done:
            i_step += 1
            action = agent.choose_action(state)
-             action = ou_noise.get_action(
-                 action, i_step)  # i.e., the random process in the paper
+             action = ou_noise.get_action(action, i_step)
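            # ou_noise.get_action perturbs the deterministic action with temporally
            # correlated noise from an Ornstein-Uhlenbeck process,
            # dx = theta*(mu - x)*dt + sigma*dW, the "random process" of the original
            # DDPG paper; its parameters live in OUNoise (DDPG/env.py).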
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            agent.memory.push(state, action, reward, next_state, done)
            agent.update()
            state = next_state
-         print('Episode:{}/{}, Reward:{}'.format(i_episode+1, cfg.train_eps, ep_reward))
+         if (i_ep+1) % 10 == 0:
+             print('Episode: {}/{}, Reward: {:.2f}'.format(i_ep+1, cfg.train_eps, ep_reward))
        rewards.append(ep_reward)
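        # ma_rewards below keeps an exponential moving average of episode returns,
        # ma[t] = 0.9*ma[t-1] + 0.1*reward[t], to smooth the learning curve.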
        if ma_rewards:
            ma_rewards.append(0.9*ma_rewards[-1] + 0.1*ep_reward)
        else:
            ma_rewards.append(ep_reward)
-     print('Complete training!')
+     print('Finished training!')
    return rewards, ma_rewards

def eval(cfg, env, agent):
-     print('Start to Eval!')
-     print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')
-     rewards = []
-     ma_rewards = []  # moving average rewards
-     for i_episode in range(cfg.eval_eps):
-         state = env.reset()
+     print('Start evaluating!')
+     print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')
+     rewards = []  # record rewards
+     ma_rewards = []  # record moving-average rewards
+     for i_ep in range(cfg.eval_eps):
+         state = env.reset()
        done = False
        ep_reward = 0
        i_step = 0
@@ -104,32 +100,29 @@ def eval(cfg, env, agent):
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            state = next_state
-         print('Episode: {}/{}, Reward: {}'.format(i_episode+1, cfg.train_eps, ep_reward))
+         print('Episode: {}/{}, Reward: {}'.format(i_ep+1, cfg.eval_eps, ep_reward))
        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(0.9*ma_rewards[-1] + 0.1*ep_reward)
        else:
            ma_rewards.append(ep_reward)
-     print('Complete Eval!')
+     print('Finished evaluating!')
    return rewards, ma_rewards


if __name__ == "__main__":
    cfg = DDPGConfig()
-
-     # train
+     # train
    env, agent = env_agent_config(cfg, seed=1)
    rewards, ma_rewards = train(cfg, env, agent)
    make_dir(cfg.result_path, cfg.model_path)
    agent.save(path=cfg.model_path)
    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
-     plot_rewards(rewards, ma_rewards, tag="train",
-                  algo=cfg.algo, path=cfg.result_path)
-
-     # eval
+     plot_rewards_cn(rewards, ma_rewards, tag="train", env=cfg.env, algo=cfg.algo, path=cfg.result_path)
+     # eval
    env, agent = env_agent_config(cfg, seed=10)
    agent.load(path=cfg.model_path)
    rewards, ma_rewards = eval(cfg, env, agent)
-     save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)
-     plot_rewards(rewards, ma_rewards, tag="eval", env=cfg.env, algo=cfg.algo, path=cfg.result_path)
+     save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)
+     plot_rewards_cn(rewards, ma_rewards, tag="eval", env=cfg.env, algo=cfg.algo, path=cfg.result_path)
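# Per result_path and model_path above, outputs land under
# <this file's directory>/outputs/Pendulum-v0/<timestamp>/{results,models}/,
# where <timestamp> follows the "%Y%m%d-%H%M%S" format set in curr_time.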