Skip to content

Commit 5a3c2de

Browse files
authored
Fix bug in MultiSnake environment (#215)
1 parent 09f225e commit 5a3c2de

File tree

3 files changed: 39 additions and 2 deletions.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Achieves > 0.98 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/MultiSnake-2-snakes-11-length--VmlldzoxNzgwNzEw
2+
ExperimentConfig(
3+
version: 0,
4+
env: (
5+
id: "MultiSnake",
6+
kwargs: "{\"num_snakes\": 2, \"max_snake_length\": 11}",
7+
),
8+
rollout: (
9+
num_envs: 512,
10+
steps: 128,
11+
processes: 16,
12+
),
13+
total_timesteps: 100000000,
14+
net: (
15+
d_model: 128,
16+
n_layer: 2,
17+
relpos_encoding: (
18+
extent: [10, 10],
19+
position_features: ["x", "y"],
20+
),
21+
),
22+
optim: (
23+
bs: 32768,
24+
lr: 0.005,
25+
max_grad_norm: 10,
26+
micro_bs: 8192,
27+
),
28+
ppo: (
29+
ent_coef: 0.03,
30+
gamma: 0.99,
31+
anneal_entropy: true,
32+
),
33+
)

entity_gym/entity_gym/environment/environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class Entity:
123 123

124 124
@dataclass
125 125
class ObsSpace:
126-
entities: Dict[str, Entity]
126+
entities: Dict[EntityType, Entity]
127 127

128 128

129 129
@dataclass

entity_gym/entity_gym/examples/multi_snake.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def act(self, action: Mapping[str, Action]) -> Observation:
125 125
self.step += 1
126 126
move_action = action["move"]
127 127
self.last_scores = deepcopy(self.scores)
128+
food_to_spawn = []
128 129
assert isinstance(move_action, CategoricalAction)
129 130
for id, move in move_action.items():
130 131
snake = self.snakes[id]
@@ -153,7 +154,8 @@ def act(self, action: Mapping[str, Action]) -> Observation:
153 154
1.0 / (self.max_snake_length - 1) / self.num_snakes
154 155
)
155 156
self.food.pop(i)
156-
self._spawn_food(snake.color)
157+
# Don't spawn food immediately since it might spawn in front of another snake that hasn't moved yet
158+
food_to_spawn.append(snake.color)
157 159
break
158 160
if not ate_food:
159 161
snake.segments = snake.segments[1:]
@@ -166,6 +168,8 @@ def act(self, action: Mapping[str, Action]) -> Observation:
166 168
]
167 169
):
168 170
game_over = True
171+
for color in food_to_spawn:
172+
self._spawn_food(color)
169 173
if self.step >= self.max_steps:
170 174
game_over = True
171 175
return self._observe(done=game_over)

0 commit comments

Comments
 (0)