11# adapted from https://github.com/vwxyzjn/cleanrl
2+ import inspect
23import json
34import os
45import random
56import time
67from dataclasses import asdict
78from pathlib import Path
8- from typing import Any , Callable , Dict , Mapping , Optional , Type
9+ from typing import Any , Callable , Dict , Mapping , Optional , Type , Union
910
1011import hyperstate
1112import numpy as np
2930from entity_gym .simple_trace import Tracer
3031from rogue_net .rogue_net import RogueNet
3132
33+ EnvFactory = Callable [[EnvConfig , int , int , int ], VecEnv ]
34+
3235
3336class SerializableRogueNet (RogueNet , hyperstate .Serializable [TrainConfig , "State" ]):
3437 def serialize (self ) -> Any :
@@ -38,11 +41,12 @@ def serialize(self) -> Any:
3841 def deserialize (
3942 clz , state_dict : Any , config : TrainConfig , state : "State" , ctx : Dict [str , Any ]
4043 ) -> "SerializableRogueNet" :
41- env_cls : Type [Environment ] = ctx ["env_cls" ]
44+ obs_space : ObsSpace = ctx ["obs_space" ]
45+ action_space : Dict [ActionType , ActionSpace ] = ctx ["action_space" ]
4246 net = SerializableRogueNet (
4347 config .net ,
44- env_cls . obs_space () ,
45- env_cls . action_space () ,
48+ obs_space ,
49+ action_space ,
4650 regression_heads = {"value" : 1 },
4751 )
4852 net .load_state_dict (state_dict )
@@ -109,8 +113,7 @@ class State(hyperstate.Lazy):
109113
110114def train (
111115 state_manager : StateManager [TrainConfig , State ],
112- env_cls : Type [Environment ],
113- create_env : Optional [Callable [[EnvConfig , int , int , int ], VecEnv ]] = None ,
116+ env : Union [Type [Environment ], EnvFactory ],
114117 create_opponent : Optional [
115118 Callable [[str , ObsSpace , Mapping [str , ActionSpace ], torch .device ], PPOAgent ]
116119 ] = None ,
@@ -185,8 +188,10 @@ def train(
185188 torch .manual_seed (cfg .seed )
186189 torch .backends .cudnn .deterministic = cfg .torch_deterministic
187190
188- if create_env is None :
189- create_env = _env_factory (env_cls )
191+ if inspect .isclass (env ) and issubclass (env , Environment ): # type: ignore
192+ create_env : EnvFactory = _env_factory (env ) # type: ignore
193+ else :
194+ create_env = env # type: ignore
190195 envs : VecEnv = AddMetricsWrapper (
191196 create_env (
192197 cfg .env ,
@@ -195,10 +200,11 @@ def train(
195200 rank * cfg .rollout .num_envs // parallelism ,
196201 ),
197202 )
198- obs_space = env_cls .obs_space ()
199- action_space = env_cls .action_space ()
203+ obs_space = envs .obs_space ()
204+ action_space = envs .action_space ()
200205
201- state_manager .set_deserialize_ctx ("env_cls" , env_cls )
206+ state_manager .set_deserialize_ctx ("obs_space" , obs_space )
207+ state_manager .set_deserialize_ctx ("action_space" , action_space )
202208 state_manager .set_deserialize_ctx ("agent" , agent )
203209 state = state_manager .state
204210 if state .step > 0 :
@@ -274,7 +280,6 @@ def _run_eval() -> None:
274280 cfg .eval ,
275281 cfg .env ,
276282 cfg .rollout ,
277- env_cls ,
278283 create_env ,
279284 create_opponent or create_random_opponent ,
280285 agent ,
@@ -616,11 +621,13 @@ def gradient_allreduce(model: Any) -> None:
616621 offset += param .numel ()
617622
618623
619- def _create_agent (cfg : TrainConfig , env_cls : Type [Environment ]) -> SerializableRogueNet :
624+ def _create_agent (
625+ cfg : TrainConfig , obs_space : ObsSpace , action_space : Dict [ActionType , ActionSpace ]
626+ ) -> SerializableRogueNet :
620627 return SerializableRogueNet (
621628 cfg .net ,
622- env_cls . obs_space () ,
623- env_cls . action_space () ,
629+ obs_space ,
630+ action_space ,
624631 regression_heads = {"value" : 1 },
625632 )
626633
@@ -637,7 +644,7 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State:
637644 if ctx .get ("agent" ) is not None :
638645 agent : SerializableRogueNet = ctx ["agent" ]
639646 else :
640- agent = _create_agent (cfg , ctx ["env_cls " ])
647+ agent = _create_agent (cfg , ctx ["obs_space" ], ctx [ "action_space " ])
641648 optimizer = SerializableAdamW (
642649 agent .parameters (),
643650 lr = cfg .optim .lr ,
@@ -647,7 +654,7 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State:
647654
648655 if cfg .vf_net is not None :
649656 value_function : Optional [SerializableRogueNet ] = _create_agent (
650- cfg , ctx ["env_cls " ]
657+ cfg , ctx ["obs_space" ], ctx [ "action_space " ]
651658 )
652659 vf_optimizer : Optional [SerializableAdamW ] = SerializableAdamW (
653660 value_function .parameters (), # type: ignore
0 commit comments