@@ -74,6 +74,7 @@ def __init__(self, sess, action_dim, action_bound, learning_rate, t_replace_iter
74
74
75
75
self .e_params = tf .get_collection (tf .GraphKeys .GLOBAL_VARIABLES , scope = 'Actor/eval_net' )
76
76
self .t_params = tf .get_collection (tf .GraphKeys .GLOBAL_VARIABLES , scope = 'Actor/target_net' )
77
+ self .replace = [tf .assign (t , e ) for t , e in zip (self .t_params , self .e_params )]
77
78
78
79
def _build_net (self , s , scope , trainable ):
79
80
with tf .variable_scope (scope ):
@@ -97,7 +98,7 @@ def _build_net(self, s, scope, trainable):
97
98
def learn (self , s ): # batch update
# Runs one training step on the batch `s`, then, on every t_replace_iter-th
# call, copies the eval-net variables into the target-net variables.
98
99
# Execute the training op with the batch fed through the S placeholder.
self .sess .run (self .train_op , feed_dict = {S : s })
99
100
# Sync only on calls where the counter is an exact multiple of t_replace_iter.
if self .t_replace_counter % self .t_replace_iter == 0 :
100
# Removed version: built a fresh list of tf.assign ops inside learn() each time a sync ran.
- self .sess .run ([ tf . assign ( t , e ) for t , e in zip ( self .t_params , self . e_params )] )
101
# Added version: runs self.replace, the list of assign ops (t <- e) created once in __init__.
+ self .sess .run (self .replace )
101
102
# Count this call so the modulo check above advances.
self .t_replace_counter += 1
102
103
103
104
def choose_action (self , s ):
@@ -145,6 +146,7 @@ def __init__(self, sess, state_dim, action_dim, learning_rate, gamma, t_replace_
145
146
146
147
with tf .variable_scope ('a_grad' ):
147
148
self .a_grads = tf .gradients (self .q , a )[0 ] # tensor of gradients of each sample (None, a_dim)
149
+ self .replace = [tf .assign (t , e ) for t , e in zip (self .t_params , self .e_params )]
148
150
149
151
def _build_net (self , s , a , scope , trainable ):
150
152
with tf .variable_scope (scope ):
@@ -170,7 +172,7 @@ def _build_net(self, s, a, scope, trainable):
170
172
def learn (self , s , a , r , s_ ):
# Runs one training step on the batch (s, a, r, s_), then, on every
# t_replace_iter-th call, runs the assign ops that copy e_params into t_params.
171
173
# Feed the batch: s -> S, a -> self.a, r -> R, s_ -> S_, and execute the training op.
self .sess .run (self .train_op , feed_dict = {S : s , self .a : a , R : r , S_ : s_ })
172
174
# Sync only on calls where the counter is an exact multiple of t_replace_iter.
if self .t_replace_counter % self .t_replace_iter == 0 :
173
# Removed version: built a fresh list of tf.assign ops inside learn() each time a sync ran.
- self .sess .run ([ tf . assign ( t , e ) for t , e in zip ( self .t_params , self . e_params )] )
175
# Added version: runs self.replace, the list of assign ops (t <- e) created once in __init__.
+ self .sess .run (self .replace )
174
176
# Count this call so the modulo check above advances.
self .t_replace_counter += 1
175
177
176
178
@@ -273,4 +275,4 @@ def eval():
273
275
# Script dispatch: call eval() when LOAD is truthy, otherwise call train().
if LOAD :
274
276
eval ()
275
277
else :
276
# NOTE(review): the removed and added train() lines read identically here;
# presumably a whitespace / end-of-file-newline-only change — confirm against the raw diff.
- train ()
278
+ train ()
0 commit comments