
Commit 5c93b38

Author: Ervin Teng (committed)
GAIL no longer uses placeholders from Policy
1 parent c2a0125 commit 5c93b38

3 files changed: +47, -36 lines
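In short, the GAIL reward signal now owns its policy-side graph inputs (action_in_policy, vector_in, visual_in) instead of feeding the policy network's selected_actions, action_holder, vector_in, and visual_in tensors. A minimal sketch of the pattern, not the actual GAILModel code (class and tensor names here are illustrative; assumes TensorFlow 1.x, which ml-agents used at this point):

import tensorflow as tf  # TF 1.x assumed; tf.placeholder does not exist in TF 2.x eager mode


class DiscriminatorInputs:
    """Hypothetical stand-in showing the idea: the discriminator creates its
    own placeholders rather than borrowing them from the policy's graph."""

    def __init__(self, vec_obs_size: int, action_size: int):
        # Policy-side inputs, now created by the reward signal itself.
        self.vector_in = tf.placeholder(shape=[None, vec_obs_size], dtype=tf.float32)
        self.action_in_policy = tf.placeholder(shape=[None, action_size], dtype=tf.float32)
        # Expert-side inputs were already separate placeholders before this change.
        self.obs_in_expert = tf.placeholder(shape=[None, vec_obs_size], dtype=tf.float32)
        self.action_in_expert = tf.placeholder(shape=[None, action_size], dtype=tf.float32)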


ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def evaluate(
             return []
 
         feed_dict = {}
-
         for i, _ in enumerate(current_info.visual_observations):
             feed_dict[self.model.visual_in[i]] = current_info.visual_observations[i]
         if self.policy.use_vec_obs:

ml-agents/mlagents/trainers/components/reward_signals/gail/model.py

Lines changed: 32 additions & 15 deletions
@@ -63,15 +63,29 @@ def make_inputs(self) -> None:
 
         if self.policy_model.brain.vector_action_space_type == "continuous":
             action_length = self.policy_model.act_size[0]
+            self.action_in_policy = tf.placeholder(
+                shape=[None, action_length], dtype=tf.float32
+            )
             self.action_in_expert = tf.placeholder(
                 shape=[None, action_length], dtype=tf.float32
             )
             self.expert_action = tf.identity(self.action_in_expert)
+            self.policy_action = tf.identity(self.action_in_policy)
         else:
             action_length = len(self.policy_model.act_size)
+            self.action_in_policy = tf.placeholder(
+                shape=[None, action_length], dtype=tf.int32
+            )
             self.action_in_expert = tf.placeholder(
                 shape=[None, action_length], dtype=tf.int32
             )
+            self.policy_action = tf.concat(
+                [
+                    tf.one_hot(self.action_in_policy[:, i], act_size)
+                    for i, act_size in enumerate(self.policy_model.act_size)
+                ],
+                axis=1,
+            )
             self.expert_action = tf.concat(
                 [
                     tf.one_hot(self.action_in_expert[:, i], act_size)
@@ -84,6 +98,9 @@ def make_inputs(self) -> None:
         encoded_expert_list = []
 
         if self.policy_model.vec_obs_size > 0:
+            self.vector_in = tf.placeholder(
+                shape=[None, self.policy_model.vec_obs_size], dtype=tf.float32
+            )
             self.obs_in_expert = tf.placeholder(
                 shape=[None, self.policy_model.vec_obs_size], dtype=tf.float32
             )
@@ -92,26 +109,33 @@ def make_inputs(self) -> None:
                 self.policy_model.normalize_vector_obs(self.obs_in_expert)
             )
             encoded_policy_list.append(
-                self.policy_model.normalize_vector_obs(self.policy_model.vector_in)
+                self.policy_model.normalize_vector_obs(self.vector_in)
             )
         else:
             encoded_expert_list.append(self.obs_in_expert)
-            encoded_policy_list.append(self.policy_model.vector_in)
+            encoded_policy_list.append(self.vector_in)
 
         if self.policy_model.vis_obs_size > 0:
             self.expert_visual_in: List[tf.Tensor] = []
+            self.visual_in: List[tf.Tensor] = []
             visual_policy_encoders = []
             visual_expert_encoders = []
             for i in range(self.policy_model.vis_obs_size):
-                # Create input ops for next (t+1) visual observations.
+                # Create input ops for visual observations.
                 visual_input = self.policy_model.create_visual_input(
                     self.policy_model.brain.camera_resolutions[i],
-                    name="visual_observation_" + str(i),
+                    name="gail_visual_observation_" + str(i),
+                )
+                self.visual_in.append(visual_input)
+                # Create input ops for next (t+1) visual observations.
+                ex_visual_input = self.policy_model.create_visual_input(
+                    self.policy_model.brain.camera_resolutions[i],
+                    name="expert_visual_observation_" + str(i),
                 )
-                self.expert_visual_in.append(visual_input)
+                self.expert_visual_in.append(ex_visual_input)
 
                 encoded_policy_visual = self.policy_model.create_visual_observation_encoder(
-                    self.policy_model.visual_in[i],
+                    self.visual_in[i],
                     self.encoding_size,
                     LearningModel.swish,
                     1,
@@ -217,10 +241,7 @@ def create_network(self) -> None:
             self.encoded_expert, self.expert_action, self.done_expert, reuse=False
         )
         self.policy_estimate, self.z_mean_policy, _ = self.create_encoder(
-            self.encoded_policy,
-            self.policy_model.selected_actions,
-            self.done_policy,
-            reuse=True,
+            self.encoded_policy, self.policy_action, self.done_policy, reuse=True
         )
         self.discriminator_score = tf.reshape(
             self.policy_estimate, [-1], name="GAIL_reward"
@@ -233,11 +254,7 @@ def create_gradient_magnitude(self) -> tf.Tensor:
         for off-policy. Compute gradients w.r.t randomly interpolated input.
         """
         expert = [self.encoded_expert, self.expert_action, self.done_expert]
-        policy = [
-            self.encoded_policy,
-            self.policy_model.selected_actions,
-            self.done_policy,
-        ]
+        policy = [self.encoded_policy, self.policy_action, self.done_policy]
         interp = []
         for _expert_in, _policy_in in zip(expert, policy):
             alpha = tf.random_uniform(tf.shape(_expert_in))
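In the discrete-action branch above, the new policy_action tensor mirrors the existing expert_action construction: each action branch is one-hot encoded and the branches are concatenated along the feature axis, giving a [batch, sum(act_size)] float tensor for the discriminator. A small self-contained sketch of that encoding (hypothetical branch sizes; TensorFlow 1.x assumed):

import numpy as np
import tensorflow as tf  # TF 1.x assumed

act_size = [3, 2]  # hypothetical multi-discrete branches
action_in_policy = tf.placeholder(shape=[None, len(act_size)], dtype=tf.int32)

# One-hot encode each branch index, then concatenate the branches.
policy_action = tf.concat(
    [tf.one_hot(action_in_policy[:, i], size) for i, size in enumerate(act_size)],
    axis=1,
)

with tf.Session() as sess:
    print(sess.run(policy_action, feed_dict={action_in_policy: np.array([[2, 0], [1, 1]])}))
    # [[0. 0. 1. 1. 0.]
    #  [0. 1. 0. 0. 1.]]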

ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py

Lines changed: 15 additions & 20 deletions
@@ -67,24 +67,19 @@ def evaluate(
     ) -> RewardSignalResult:
         if len(current_info.agents) == 0:
             return []
-
-        feed_dict: Dict[tf.Tensor, Any] = {
-            self.policy.model.batch_size: len(next_info.vector_observations),
-            self.policy.model.sequence_length: 1,
-        }
+        feed_dict: Dict[tf.Tensor, Any] = {}
         if self.model.use_vail:
             feed_dict[self.model.use_noise] = [0]
+        for i, _ in enumerate(current_info.visual_observations):
+            feed_dict[self.model.visual_in[i]] = current_info.visual_observations[i]
+        if self.policy.use_vec_obs:
+            feed_dict[self.model.vector_in] = current_info.vector_observations
 
-        feed_dict = self.policy.fill_eval_dict(feed_dict, brain_info=current_info)
         feed_dict[self.model.done_policy] = np.reshape(next_info.local_done, [-1, 1])
         if self.policy.use_continuous_act:
-            feed_dict[
-                self.policy.model.selected_actions
-            ] = next_info.previous_vector_actions
+            feed_dict[self.model.action_in_policy] = next_info.previous_vector_actions
         else:
-            feed_dict[
-                self.policy.model.action_holder
-            ] = next_info.previous_vector_actions
+            feed_dict[self.model.action_in_policy] = next_info.previous_vector_actions
         unscaled_reward = self.policy.sess.run(
             self.model.intrinsic_reward, feed_dict=feed_dict
         )
@@ -134,30 +129,30 @@ def prepare_update(
             feed_dict[self.model.use_noise] = [1]
 
         if self.policy.use_continuous_act:
-            feed_dict[self.policy.model.selected_actions] = mini_batch_policy[
+            feed_dict[self.model.action_in_policy] = mini_batch_policy[
                 "actions"
             ].reshape([-1, self.policy.model.act_size[0]])
             feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
                 [-1, self.policy.model.act_size[0]]
             )
         else:
-            feed_dict[self.policy.model.action_holder] = mini_batch_policy[
+            feed_dict[self.model.action_in_policy] = mini_batch_policy[
                 "actions"
             ].reshape([-1, len(self.policy.model.act_size)])
             feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
                 [-1, len(self.policy.model.act_size)]
             )
 
         if self.policy.use_vis_obs > 0:
-            for i in range(len(self.policy.model.visual_in)):
+            for i in range(len(self.model.visual_in)):
                 policy_obs = mini_batch_policy["visual_obs%d" % i]
                 if self.policy.sequence_length > 1 and self.policy.use_recurrent:
                     (_batch, _seq, _w, _h, _c) = policy_obs.shape
-                    feed_dict[self.policy.model.visual_in[i]] = policy_obs.reshape(
+                    feed_dict[self.model.visual_in[i]] = policy_obs.reshape(
                         [-1, _w, _h, _c]
                     )
                 else:
-                    feed_dict[self.policy.model.visual_in[i]] = policy_obs
+                    feed_dict[self.model.visual_in[i]] = policy_obs
 
                 demo_obs = mini_batch_demo["visual_obs%d" % i]
                 if self.policy.sequence_length > 1 and self.policy.use_recurrent:
@@ -168,9 +163,9 @@ def prepare_update(
                 else:
                     feed_dict[self.model.expert_visual_in[i]] = demo_obs
         if self.policy.use_vec_obs:
-            feed_dict[self.policy.model.vector_in] = mini_batch_policy[
-                "vector_obs"
-            ].reshape([-1, self.policy.vec_obs_size])
+            feed_dict[self.model.vector_in] = mini_batch_policy["vector_obs"].reshape(
+                [-1, self.policy.vec_obs_size]
+            )
             feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"].reshape(
                 [-1, self.policy.vec_obs_size]
             )
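Taken together, evaluate() now fills only the GAIL model's own placeholders, and the continuous and discrete branches collapse onto the same action_in_policy input. A condensed sketch of that feed path (not the exact ml-agents implementation; model is the GAIL model, policy is the TFPolicy, and the BrainInfo fields are as used above):

import numpy as np

def build_gail_eval_feed_dict(model, policy, current_info, next_info):
    # Sketch only: mirrors the evaluate() logic in the diff above.
    feed_dict = {}
    if model.use_vail:
        feed_dict[model.use_noise] = [0]
    for i, _ in enumerate(current_info.visual_observations):
        feed_dict[model.visual_in[i]] = current_info.visual_observations[i]
    if policy.use_vec_obs:
        feed_dict[model.vector_in] = current_info.vector_observations
    feed_dict[model.done_policy] = np.reshape(next_info.local_done, [-1, 1])
    # Same placeholder whether the action space is continuous or discrete.
    feed_dict[model.action_in_policy] = next_info.previous_vector_actions
    return feed_dict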
