Skip to content

Commit 97dba9b

Browse files
authored
Merge pull request MorvanZhou#136 from Gaoee/master
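Fixed a gradual slowdown during training: the target-network `tf.assign` ops are now built once in `__init__` instead of being re-created on every replacement in `learn`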
2 parents 967c829 + cd3b606 commit 97dba9b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

experiments/Robot_arm/DDPG.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,7 @@ def __init__(self, sess, action_dim, action_bound, learning_rate, t_replace_iter
7474

7575
self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval_net')
7676
self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target_net')
77+
self.replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]
7778

7879
def _build_net(self, s, scope, trainable):
7980
with tf.variable_scope(scope):
@@ -97,7 +98,7 @@ def _build_net(self, s, scope, trainable):
9798
def learn(self, s): # batch update
9899
self.sess.run(self.train_op, feed_dict={S: s})
99100
if self.t_replace_counter % self.t_replace_iter == 0:
100-
self.sess.run([tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)])
101+
self.sess.run(self.replace)
101102
self.t_replace_counter += 1
102103

103104
def choose_action(self, s):
@@ -145,6 +146,7 @@ def __init__(self, sess, state_dim, action_dim, learning_rate, gamma, t_replace_
145146

146147
with tf.variable_scope('a_grad'):
147148
self.a_grads = tf.gradients(self.q, a)[0] # tensor of gradients of each sample (None, a_dim)
149+
self.replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]
148150

149151
def _build_net(self, s, a, scope, trainable):
150152
with tf.variable_scope(scope):
@@ -170,7 +172,7 @@ def _build_net(self, s, a, scope, trainable):
170172
def learn(self, s, a, r, s_):
171173
self.sess.run(self.train_op, feed_dict={S: s, self.a: a, R: r, S_: s_})
172174
if self.t_replace_counter % self.t_replace_iter == 0:
173-
self.sess.run([tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)])
175+
self.sess.run(self.replace)
174176
self.t_replace_counter += 1
175177

176178

@@ -273,4 +275,4 @@ def eval():
273275
if LOAD:
274276
eval()
275277
else:
276-
train()
278+
train()

0 commit comments

Comments
 (0)