Skip to content

Commit f29bf92

Browse files
authored
Checkpointing (#221)
1 parent c174561 commit f29bf92

File tree

14 files changed

+749
-521
lines changed

14 files changed

+749
-521
lines changed

configs/codecraft/arena_medium.ron

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Achieves ~0.4 against eval opponent (old baseline: ~0.8-0.95): https://wandb.ai/entity-neural-network/enn-ppo/reports/Arena-Medium-baseline--VmlldzoxNzgwMTM1
22
ExperimentConfig(
3+
version: 1,
34
env: (
45
id: "CodeCraft",
56
kwargs: "{\"objective\": \"ARENA_MEDIUM\", \"hardness\": 1.0, \"win_bonus\": 2.0, \"hidden_obs\": true}",

configs/xprun/train.ron

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ XpV0(
9696
env_secrets: {
9797
"WANDB_API_KEY": "wandb-api-key",
9898
},
99-
volumes: {
100-
"/mnt/a/Dropbox/artifacts/xprun": "/mnt/xprun",
101-
},
10299
)
103100
}
104101
)

configs/xprun/trainbc.ron

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ XpV0(
6969
env_secrets: {
7070
"WANDB_API_KEY": "wandb-api-key",
7171
},
72-
volumes: {
73-
"/mnt/a/Dropbox/artifacts/xprun": "/mnt/xprun",
74-
},
7572
)
7673
}
7774
)

configs/xprun/traincc.ron

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ XpV0(
9696
env_secrets: {
9797
"WANDB_API_KEY": "wandb-api-key",
9898
},
99-
volumes: {
100-
"/mnt/a/Dropbox/artifacts/xprun": "/mnt/xprun",
101-
},
10299
),
103100

104101
"codecraftserver": (

enn_ppo/enn_ppo/tests/test_training.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
from hyperstate import StateManager
2+
13
from enn_ppo.config import RolloutConfig
2-
from enn_ppo.train import EnvConfig, OptimizerConfig, PPOConfig, TrainConfig, _train
4+
from enn_ppo.train import (
5+
EnvConfig,
6+
OptimizerConfig,
7+
PPOConfig,
8+
State,
9+
TrainConfig,
10+
initialize,
11+
train,
12+
)
13+
from entity_gym.examples import ENV_REGISTRY
314
from rogue_net.relpos_encoding import RelposEncodingConfig
415
from rogue_net.rogue_net import RogueNetConfig
516

617

18+
def _train(cfg: TrainConfig) -> float:
    """Run ``train`` against ``cfg`` through a throwaway ``StateManager``.

    The manager is created with ``None`` as its final argument (presumably
    "no checkpoint path" -- TODO confirm against hyperstate), and its private
    ``_config`` attribute is then overwritten so that ``train`` receives
    exactly the config object passed in here.

    Returns whatever ``train`` returns for the environment registered under
    ``cfg.env.id`` in ``ENV_REGISTRY``.
    """
    state_manager = StateManager(TrainConfig, State, initialize, None)
    # Inject the in-memory config directly via the private attribute.
    state_manager._config = cfg
    env_entry = ENV_REGISTRY[cfg.env.id]
    return train(state_manager, env_entry)
22+
23+
724
def test_multi_armed_bandit() -> None:
825
cfg = TrainConfig(
926
total_timesteps=500,

0 commit comments

Comments
 (0)