Skip to content

Commit 4f86139

Browse files
authored
Misc bugfixes (#237)
1 parent 315496f commit 4f86139

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

enn_ppo/enn_ppo/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,14 @@ def train(
166166
xp_info.xp_def.project,
167167
xp_info.sanitized_name + "-" + xp_info.id,
168168
)
169+
id = xp_info.id
169170
Path(str(out_dir)).mkdir(parents=True, exist_ok=True)
170171

171172
init_process(xp_info)
172173
rank = xp_info.replica_index
173174
parallelism = xp_info.replicas()
174175
else:
176+
id = None
175177
out_dir = None
176178
rank = 0
177179
parallelism = 1
@@ -253,6 +255,7 @@ def train(
253255
name=run_name,
254256
save_code=True,
255257
dir=data_dir,
258+
id=id,
256259
)
257260
wandb.watch(agent)
258261

enn_zoo/enn_zoo/procgen_env/base_env.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,21 @@ def observe(self) -> Observation:
118118
[0],
119119
)
120120
}
121+
total = 1
121122
for type_id, name in self._entity_types().items():
122123
feats = state.entities[state.entities[:, 6] == type_id]
123124
if feats.shape[0] > 0:
124125
feats = np.concatenate(
125126
[feats, global_feats.repeat(feats.shape[0], axis=0)],
126127
axis=1,
127128
)
129+
total += feats.shape[0]
128130
else:
129131
feats = np.zeros(
130132
(0, feats.shape[1] + global_feats.shape[1]), dtype=np.float32
131133
)
132134
entities[name] = feats
133-
assert (
134-
sum(e.features.shape[0] for e in entities.values()) # type: ignore
135-
== state.entities.shape[0]
136-
)
135+
assert total == state.entities.shape[0]
137136

138137
return Observation(
139138
entities=entities,

rogue_net/rogue_net/relpos_encoding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def __init__(
172172
if self.value_gate is not None:
173173
self.value_gate_proj = nn.Linear(dmodel, dhead)
174174
self.cached_rkvs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
175+
self.global_entity = "__global__" in obs_space.entities
175176

176177
def relattn_logits(self, queries: torch.Tensor) -> torch.Tensor:
177178
assert self.cached_rkvs is not None
@@ -213,6 +214,8 @@ def keys_values(
213214
# Type of each entity
214215
entity_type: torch.Tensor,
215216
) -> Tuple[torch.Tensor, torch.Tensor]:
217+
if "__global__" in x and not self.global_entity:
218+
x = {k: v for k, v in x.items() if k != "__global__"}
216219
relative_positions, entity_type = self._relative_positions(
217220
x, index_map, packpad_index, shape, entity_type
218221
)

0 commit comments

Comments
 (0)