Commit eae455b

Tutorial for enn-ppo (#230)

1 parent cc9b679 commit eae455b

17 files changed: +878 -508 lines

enn_ppo/enn_ppo/__init__.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -0,0 +1,6 @@
+from .load_checkpoint import load_agent, load_checkpoint
+
+__all__ = [
+    "load_checkpoint",
+    "load_agent",
+]
```
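
With these re-exports, both helpers are importable directly from the package root:

```python
from enn_ppo import load_agent, load_checkpoint
```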

enn_ppo/enn_ppo/agent.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -5,8 +5,13 @@
 import torch
 from ragged_buffer import RaggedBufferBool, RaggedBufferF32, RaggedBufferI64
 
-from entity_gym.environment import VecActionMask
+import entity_gym.agent
+from entity_gym.environment import Action, Observation, VecActionMask
+from entity_gym.environment.env_list import action_index_to_actions
+from entity_gym.environment.environment import ActionType
+from entity_gym.environment.vec_env import batch_obs
 from entity_gym.simple_trace import Tracer
+from rogue_net.rogue_net import RogueNet
 
 
 class PPOAgent(Protocol):
@@ -41,3 +46,26 @@ def get_auxiliary_head(
         tracer: Tracer,
     ) -> torch.Tensor:
         ...
+
+
+class RogueNetAgent(entity_gym.agent.Agent):
+    def __init__(self, agent: RogueNet):
+        self.agent = agent
+
+    def act(self, obs: Observation) -> Tuple[Dict[ActionType, Action], float]:
+        vec_obs = batch_obs([obs], self.agent.obs_space, self.agent.action_space)
+        with torch.no_grad():
+            act_indices, _, _, _, aux, logits = self.agent.get_action_and_auxiliary(
+                vec_obs.features,
+                vec_obs.visible,
+                vec_obs.action_masks,
+                tracer=Tracer(False),
+            )
+        actions = action_index_to_actions(
+            self.agent.obs_space,
+            self.agent.action_space,
+            act_indices,
+            obs,
+            probs={k: l.exp().cpu().numpy() for k, l in logits.items()},
+        )
+        return actions, float(aux["value"].item())
```
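
A minimal inference sketch using the new wrapper; it assumes a trained `RogueNet` policy `net` and an `entity_gym` environment `env` already exist (both names are placeholders, not part of this diff):

```python
from enn_ppo.agent import RogueNetAgent

# net is a placeholder for a trained RogueNet policy.
agent = RogueNetAgent(net)

# obs is a single entity_gym Observation, e.g. from env.reset().
obs = env.reset()

# act() batches the observation, samples one action per action type,
# and returns the sampled actions plus the value-head estimate.
actions, value = agent.act(obs)
```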

enn_ppo/enn_ppo/load_checkpoint.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -0,0 +1,12 @@
+from hyperstate import StateManager
+
+from .config import TrainConfig
+from .train import State, initialize
+
+
+def load_checkpoint(path: str) -> StateManager[TrainConfig, State]:
+    return StateManager(TrainConfig, State, initialize, init_path=path)
+
+
+def load_agent(path: str) -> State:
+    return StateManager(TrainConfig, State, initialize, init_path=path).state
```
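
Combined with the `RogueNetAgent` wrapper above, restoring a trained policy for inference becomes a two-liner. A sketch, assuming a placeholder checkpoint path and that the deserialized `State` stores the policy under `state.agent` (that field name is not shown in this diff):

```python
from enn_ppo import load_agent
from enn_ppo.agent import RogueNetAgent

# Placeholder path; point this at a checkpoint written by the trainer.
state = load_agent("checkpoints/latest")

# Assumed field name for the serialized RogueNet held by State.
agent = RogueNetAgent(state.agent)
```

Note that, its name notwithstanding, `load_agent` returns the entire deserialized `State`, while `load_checkpoint` returns the `StateManager` itself.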

enn_ppo/enn_ppo/train.py

Lines changed: 41 additions & 26 deletions
```diff
@@ -8,6 +8,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
 
+import click
 import hyperstate
 import numpy as np
 import torch
@@ -42,12 +43,10 @@ def serialize(self) -> Any:
     def deserialize(
         clz, state_dict: Any, config: TrainConfig, state: "State", ctx: Dict[str, Any]
     ) -> "SerializableRogueNet":
-        obs_space: ObsSpace = ctx["obs_space"]
-        action_space: Dict[ActionType, ActionSpace] = ctx["action_space"]
         net = SerializableRogueNet(
             config.net,
-            obs_space,
-            action_space,
+            state.obs_space,
+            state.action_space,
             regression_heads={"value": 1},
         )
         net.load_state_dict(state_dict)
```
```diff
@@ -114,6 +113,8 @@ class State(hyperstate.Lazy):
     value_function: Optional[SerializableRogueNet]
     optimizer: SerializableAdamW
     vf_optimizer: Optional[SerializableAdamW]
+    obs_space: ObsSpace
+    action_space: Dict[str, ActionSpace]
 
 
 def train(
```
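
Persisting `obs_space` and `action_space` on `State` makes checkpoints self-describing: as the `deserialize` change above shows, the network can now be rebuilt from the checkpoint alone rather than requiring the caller to thread both spaces in through `ctx` (which `initialize` still does for fresh runs; see the final hunk below).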
```diff
@@ -358,26 +359,6 @@ def _run_eval() -> None:
             writer.add_scalar(f"{name}.max", value.max, global_step)
             writer.add_scalar(f"{name}.min", value.min, global_step)
             writer.add_scalar(f"{name}.count", value.count, global_step)
-        # Double log these to remain compatible with old naming scheme
-        # TODO: remove before release
-        writer.add_scalar(
-            "charts/episodic_return",
-            metrics["episodic_reward"].mean,
-            global_step,
-        )
-        writer.add_scalar(
-            "charts/episodic_length",
-            metrics["episode_length"].mean,
-            global_step,
-        )
-        writer.add_scalar(
-            "charts/episodes", metrics["episodic_reward"].count, global_step
-        )
-        writer.add_scalar("meanrew", metrics["reward"].mean, global_step)
-
-        print(
-            f"global_step={global_step} {' '.join(f'{name}={value.mean}' for name, value in metrics.items())}"
-        )
 
         values = rollout.values
         actions = rollout.actions
@@ -571,9 +552,41 @@ def _run_eval() -> None:
                 np.sum(_actions == i).item() / len(_actions),
                 global_step,
             )
-        print(
-            "SPS:", int((global_step - initial_step) / (time.time() - start_time))
+
+        fps = (global_step - initial_step) / (time.time() - start_time)
+        digits = int(np.ceil(np.log10(cfg.total_timesteps)))
+        episodic_reward = metrics["episodic_reward"].mean
+        episode_length = metrics["episode_length"].mean
+        episode_count = metrics["episode_length"].count
+        mean_reward = metrics["reward"].mean
+
+        def green(s: str) -> str:
+            return click.style(s, fg="cyan")
+
+        def estyle(f: float) -> str:
+            return click.style(f"{f:.2e}", fg="cyan")
+
+        def fstyle(f: float) -> str:
+            return click.style(f"{f:5.2f}", fg="cyan")
+
+        def tstyle(s: str) -> str:
+            return s
+
+        def symstyle(s: str) -> str:
+            return click.style(s, fg="white", bold=True)
+
+        # fmt: off
+        click.echo(
+            green(f"{global_step:>{digits}}") + symstyle("/") + green(f"{cfg.total_timesteps} ")
+            + f"{symstyle('|')} {tstyle('meanrew')} {estyle(mean_reward)} "
+            + f"{symstyle('|')} {tstyle('explained_var')} {fstyle(explained_var.item())} "
+            + f"{symstyle('|')} {tstyle('entropy')} {fstyle(entropy_loss.item())} "
+            + f"{symstyle('|')} {tstyle('episodic_reward')} {estyle(episodic_reward)} "
+            + f"{symstyle('|')} {tstyle('episode_length')} {estyle(episode_length)} "
+            + f"{symstyle('|')} {tstyle('episodes')} {green(str(episode_count))} "
+            + f"{symstyle('|')} {tstyle('fps')} {green(str(int(fps)))}"
         )
+        # fmt: on
         writer.add_scalar(
             "charts/SPS",
             int((global_step - initial_step) / (time.time() - start_time)),
```
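
With the helpers above, each logging step now prints one compact, color-coded status line instead of the old `print` output. With made-up numbers (and colors omitted), it looks roughly like:

```text
 40960/1000000 | meanrew 1.23e-01 | explained_var  0.42 | entropy  1.05 | episodic_reward 2.50e+00 | episode_length 1.20e+02 | episodes 34 | fps 1800
```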
```diff
@@ -679,6 +692,8 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State:
         value_function=value_function,
         optimizer=optimizer,
         vf_optimizer=vf_optimizer,
+        obs_space=ctx["obs_space"],
+        action_space=ctx["action_space"],
     )
```