|
8 | 8 | from pathlib import Path |
9 | 9 | from typing import Any, Callable, Dict, Mapping, Optional, Type, Union |
10 | 10 |
|
| 11 | +import click |
11 | 12 | import hyperstate |
12 | 13 | import numpy as np |
13 | 14 | import torch |
@@ -42,12 +43,10 @@ def serialize(self) -> Any: |
42 | 43 | def deserialize( |
43 | 44 | clz, state_dict: Any, config: TrainConfig, state: "State", ctx: Dict[str, Any] |
44 | 45 | ) -> "SerializableRogueNet": |
45 | | - obs_space: ObsSpace = ctx["obs_space"] |
46 | | - action_space: Dict[ActionType, ActionSpace] = ctx["action_space"] |
47 | 46 | net = SerializableRogueNet( |
48 | 47 | config.net, |
49 | | - obs_space, |
50 | | - action_space, |
| 48 | + state.obs_space, |
| 49 | + state.action_space, |
51 | 50 | regression_heads={"value": 1}, |
52 | 51 | ) |
53 | 52 | net.load_state_dict(state_dict) |
@@ -114,6 +113,8 @@ class State(hyperstate.Lazy): |
114 | 113 | value_function: Optional[SerializableRogueNet] |
115 | 114 | optimizer: SerializableAdamW |
116 | 115 | vf_optimizer: Optional[SerializableAdamW] |
| 116 | + obs_space: ObsSpace |
| 117 | + action_space: Dict[str, ActionSpace] |
117 | 118 |
|
118 | 119 |
|
119 | 120 | def train( |
@@ -358,26 +359,6 @@ def _run_eval() -> None: |
358 | 359 | writer.add_scalar(f"{name}.max", value.max, global_step) |
359 | 360 | writer.add_scalar(f"{name}.min", value.min, global_step) |
360 | 361 | writer.add_scalar(f"{name}.count", value.count, global_step) |
361 | | - # Double log these to remain compatible with old naming scheme |
362 | | - # TODO: remove before release |
363 | | - writer.add_scalar( |
364 | | - "charts/episodic_return", |
365 | | - metrics["episodic_reward"].mean, |
366 | | - global_step, |
367 | | - ) |
368 | | - writer.add_scalar( |
369 | | - "charts/episodic_length", |
370 | | - metrics["episode_length"].mean, |
371 | | - global_step, |
372 | | - ) |
373 | | - writer.add_scalar( |
374 | | - "charts/episodes", metrics["episodic_reward"].count, global_step |
375 | | - ) |
376 | | - writer.add_scalar("meanrew", metrics["reward"].mean, global_step) |
377 | | - |
378 | | - print( |
379 | | - f"global_step={global_step} {' '.join(f'{name}={value.mean}' for name, value in metrics.items())}" |
380 | | - ) |
381 | 362 |
|
382 | 363 | values = rollout.values |
383 | 364 | actions = rollout.actions |
@@ -571,9 +552,41 @@ def _run_eval() -> None: |
571 | 552 | np.sum(_actions == i).item() / len(_actions), |
572 | 553 | global_step, |
573 | 554 | ) |
574 | | - print( |
575 | | - "SPS:", int((global_step - initial_step) / (time.time() - start_time)) |
| 555 | + |
| 556 | + fps = (global_step - initial_step) / (time.time() - start_time) |
| 557 | + digits = int(np.ceil(np.log10(cfg.total_timesteps))) |
| 558 | + episodic_reward = metrics["episodic_reward"].mean |
| 559 | + episode_length = metrics["episode_length"].mean |
| 560 | + episode_count = metrics["episode_length"].count |
| 561 | + mean_reward = metrics["reward"].mean |
| 562 | + |
| 563 | + def green(s: str) -> str: |
| 564 | + return click.style(s, fg="cyan") |
| 565 | + |
| 566 | + def estyle(f: float) -> str: |
| 567 | + return click.style(f"{f:.2e}", fg="cyan") |
| 568 | + |
| 569 | + def fstyle(f: float) -> str: |
| 570 | + return click.style(f"{f:5.2f}", fg="cyan") |
| 571 | + |
| 572 | + def tstyle(s: str) -> str: |
| 573 | + return s |
| 574 | + |
| 575 | + def symstyle(s: str) -> str: |
| 576 | + return click.style(s, fg="white", bold=True) |
| 577 | + |
| 578 | + # fmt: off |
| 579 | + click.echo( |
| 580 | + green(f"{global_step:>{digits}}") + symstyle("/") + green(f"{cfg.total_timesteps} ") |
| 581 | + + f"{symstyle('|')} {tstyle('meanrew')} {estyle(mean_reward)} " |
| 582 | + + f"{symstyle('|')} {tstyle('explained_var')} {fstyle(explained_var.item())} " |
| 583 | + + f"{symstyle('|')} {tstyle('entropy')} {fstyle(entropy_loss.item())} " |
| 584 | + + f"{symstyle('|')} {tstyle('episodic_reward')} {estyle(episodic_reward)} " |
| 585 | + + f"{symstyle('|')} {tstyle('episode_length')} {estyle(episode_length)} " |
| 586 | + + f"{symstyle('|')} {tstyle('episodes')} {green(str(episode_count))} " |
| 587 | + + f"{symstyle('|')} {tstyle('fps')} {green(str(int(fps)))}" |
576 | 588 | ) |
| 589 | + # fmt: on |
577 | 590 | writer.add_scalar( |
578 | 591 | "charts/SPS", |
579 | 592 | int((global_step - initial_step) / (time.time() - start_time)), |
@@ -679,6 +692,8 @@ def initialize(cfg: TrainConfig, ctx: Dict[str, Any]) -> State: |
679 | 692 | value_function=value_function, |
680 | 693 | optimizer=optimizer, |
681 | 694 | vf_optimizer=vf_optimizer, |
| 695 | + obs_space=ctx["obs_space"], |
| 696 | + action_space=ctx["action_space"], |
682 | 697 | ) |
683 | 698 |
|
684 | 699 |
|
|
0 commit comments