Skip to content

Commit ad92a0d

Browse files
authored
Make obs_space and action_space methods (#224)
1 parent 3a82794 commit ad92a0d

36 files changed

+301
-367
lines changed

enn_ppo/enn_ppo/eval.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Mapping, Optional, Tuple, Type, Union
1+
from typing import Callable, List, Mapping, Optional, Tuple, Union
22

33
import numpy as np
44
import numpy.typing as npt
@@ -19,7 +19,6 @@ def run_eval(
1919
cfg: EvalConfig,
2020
env_cfg: EnvConfig,
2121
rollout: RolloutConfig,
22-
env_cls: Type[Environment],
2322
create_env: Callable[[EnvConfig, int, int, int], VecEnv],
2423
create_opponent: Callable[
2524
[str, ObsSpace, Mapping[str, ActionSpace], torch.device], PPOAgent
@@ -35,8 +34,18 @@ def run_eval(
3534
# TODO: metrics are biased towards short episodes
3635
processes = cfg.processes or rollout.processes
3736
num_envs = cfg.num_envs or rollout.num_envs
38-
obs_space = env_cls.obs_space()
39-
action_space = env_cls.action_space()
37+
38+
envs: VecEnv = AddMetricsWrapper(
39+
create_env(
40+
cfg.env or env_cfg,
41+
num_envs // parallelism,
42+
processes,
43+
rank * num_envs // parallelism,
44+
),
45+
metric_filter,
46+
)
47+
obs_space = envs.obs_space()
48+
action_space = envs.action_space()
4049

4150
assert num_envs % parallelism == 0, (
4251
"Number of eval environments must be divisible by parallelism: "
@@ -62,16 +71,6 @@ def run_eval(
6271
else:
6372
agents = agent
6473

65-
envs: VecEnv = AddMetricsWrapper(
66-
create_env(
67-
cfg.env or env_cfg,
68-
num_envs // parallelism,
69-
processes,
70-
rank * num_envs // parallelism,
71-
),
72-
metric_filter,
73-
)
74-
7574
if cfg.capture_samples:
7675
envs = SampleRecordingVecEnv(
7776
envs, cfg.capture_samples, cfg.capture_samples_subsample

enn_ppo/enn_ppo/train.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# adapted from https://github.com/vwxyzjn/cleanrl
2+
import inspect
23
import json
34
import os
45
import random
56
import time
67
from dataclasses import asdict
78
from pathlib import Path
8-
from typing import Any, Callable, Dict, Mapping, Optional, Type
9+
from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
910

1011
import hyperstate
1112
import numpy as np
@@ -29,6 +30,8 @@
2930
from entity_gym.simple_trace import Tracer
3031
from rogue_net.rogue_net import RogueNet
3132

33+
EnvFactory = Callable[[EnvConfig, int, int, int], VecEnv]
34+
3235

3336
class SerializableRogueNet(RogueNet, hyperstate.Serializable[TrainConfig, "State"]):
3437
def serialize(self) -> Any:
@@ -38,11 +41,12 @@ def serialize(self) -> Any:
3841
def deserialize(
3942
clz, state_dict: Any, config: TrainConfig, state: "State", ctx: Dict[str, Any]
4043
) -> "SerializableRogueNet":
41-
env_cls: Type[Environment] = ctx["env_cls"]
44+
obs_space: ObsSpace = ctx["obs_space"]
45+
action_space: Dict[ActionType, ActionSpace] = ctx["action_space"]
4246
net = SerializableRogueNet(
4347
config.net,
44-
env_cls.obs_space(),
45-
env_cls.action_space(),
48+
obs_space,
49+
action_space,
4650
regression_heads={"value": 1},
4751
)
4852
net.load_state_dict(state_dict)
@@ -109,8 +113,7 @@ class State(hyperstate.Lazy):
109113

110114
def train(
111115
state_manager: StateManager[TrainConfig, State],
112-
env_cls: Type[Environment],
113-
create_env: Optional[Callable[[EnvConfig, int, int, int], VecEnv]] = None,
116+
env: Union[Type[Environment], EnvFactory],
114117
create_opponent: Optional[
115118
Callable[[str, ObsSpace, Mapping[str, ActionSpace], torch.device], PPOAgent]
116119
] = None,
@@ -185,8 +188,10 @@ def train(
185188
torch.manual_seed(cfg.seed)
186189
torch.backends.cudnn.deterministic = cfg.torch_deterministic
187190

188-
if create_env is None:
189-
create_env = _env_factory(env_cls)
191+
if inspect.isclass(env) and issubclass(env, Environment): # type: ignore
192+
create_env: EnvFactory = _env_factory(env) # type: ignore
193+
else:
194+
create_env = env # type: ignore
190195
envs: VecEnv = AddMetricsWrapper(
191196
create_env(
192197
cfg.env,
@@ -195,10 +200,11 @@ def train(
195200
rank * cfg.rollout.num_envs // parallelism,
196201
),
197202
)
198-
obs_space = env_cls.obs_space()
199-
action_space = env_cls.action_space()
203+
obs_space = envs.obs_space()
204+
action_space = envs.action_space()
200205

201-
state_manager.set_deserialize_ctx("env_cls", env_cls)
206+
state_manager.set_deserialize_ctx("obs_space", obs_space)
207+
state_manager.set_deserialize_ctx("action_space", action_space)
202208
state_manager.set_deserialize_ctx("agent", agent)
203209
state = state_manager.state
204210
if state.step > 0:
@@ -274,7 +280,6 @@ def _run_eval() -> None:
274280
cfg.eval,
275281
cfg.env,
276282
cfg.rollout,
277-
env_cls,
278283
create_env,
279284
create_opponent or create_random_opponent,
280285
agent,
@@ -616,11 +621,13 @@ def gradient_allreduce(model: Any) -> None:
616621
offset += param.numel()
617622

618623

619-
def _create_agent(cfg: TrainConfig, env_cls: Type[Environment]) -> SerializableRogueNet:
624+
def _create_agent(
625+
cfg: TrainConfig, obs_space: ObsSpace, action_space: Dict[ActionType, ActionSpace]
626+
) -> SerializableRogueNet:
620627
return SerializableRogueNet(
621628
cfg.net,
622-
env_cls.obs_space(),
623-
env_cls.action_space(),
629+
obs_space,
630+
action_space,
624631
regression_heads={"value": 1},
625632
)
626633

@@ -637,7 +644,7 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State:
637644
if ctx.get("agent") is not None:
638645
agent: SerializableRogueNet = ctx["agent"]
639646
else:
640-
agent = _create_agent(cfg, ctx["env_cls"])
647+
agent = _create_agent(cfg, ctx["obs_space"], ctx["action_space"])
641648
optimizer = SerializableAdamW(
642649
agent.parameters(),
643650
lr=cfg.optim.lr,
@@ -647,7 +654,7 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State:
647654

648655
if cfg.vf_net is not None:
649656
value_function: Optional[SerializableRogueNet] = _create_agent(
650-
cfg, ctx["env_cls"]
657+
cfg, ctx["obs_space"], ctx["action_space"]
651658
)
652659
vf_optimizer: Optional[SerializableAdamW] = SerializableAdamW(
653660
value_function.parameters(), # type: ignore

0 commit comments

Comments
 (0)