Skip to content

Commit 678e96a

Browse files
author
Vincent Moens
committed
[Algorithm] Async SAC
ghstack-source-id: 23869ee Pull-Request-resolved: #2946
1 parent 604cd7e commit 678e96a

File tree

6 files changed

+425
-30
lines changed

6 files changed

+425
-30
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# environment and task
2+
env:
3+
name: HalfCheetah-v4
4+
task: ""
5+
library: gymnasium
6+
max_episode_steps: 1000
7+
seed: 42
8+
9+
# collector
10+
collector:
11+
total_frames: -1
12+
init_random_frames: 0
13+
frames_per_batch: 8000
14+
init_env_steps: 1000
15+
device: cuda:1
16+
env_per_collector: 8
17+
reset_at_each_iter: False
18+
update_freq: 10_000
19+
20+
# replay buffer
21+
replay_buffer:
22+
size: 1000000
23+
prb: 0 # use prioritized experience replay
24+
scratch_dir:
25+
26+
# optim
27+
optim:
28+
utd_ratio: 1.0
29+
gamma: 0.99
30+
loss_function: l2
31+
lr: 3.0e-4
32+
weight_decay: 0.0
33+
batch_size: 256
34+
target_update_polyak: 0.995
35+
alpha_init: 1.0
36+
adam_eps: 1.0e-8
37+
38+
# network
39+
network:
40+
hidden_sizes: [256, 256]
41+
activation: relu
42+
default_policy_scale: 1.0
43+
scale_lb: 0.1
44+
device:
45+
46+
# logging
47+
logger:
48+
backend: wandb
49+
project_name: torchrl_example_sac
50+
group_name: null
51+
exp_name: ${env.name}_SAC
52+
mode: online
53+
eval_iter: 25000
54+
video: False
55+
56+
compile:
57+
compile: False
58+
compile_mode:
59+
cudagraphs: False

sota-implementations/sac/sac-async.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""SAC Example.
6+
7+
This is a simple self-contained example of a SAC training script.
8+
9+
It supports state environments like MuJoCo.
10+
11+
The helper functions are coded in the utils.py associated with this script.
12+
"""
13+
from __future__ import annotations
14+
15+
import time
16+
17+
import warnings
18+
from functools import partial
19+
20+
import hydra
21+
import numpy as np
22+
import torch
23+
import torch.cuda
24+
import tqdm
25+
from tensordict import TensorDict
26+
from tensordict.nn import CudaGraphModule
27+
from torchrl._utils import compile_with_warmup, timeit
28+
from torchrl.envs.utils import ExplorationType, set_exploration_type
29+
from torchrl.objectives import group_optimizers
30+
from torchrl.record.loggers import generate_exp_name, get_logger
31+
from utils import (
32+
dump_video,
33+
log_metrics,
34+
make_collector_async,
35+
make_environment,
36+
make_loss_module,
37+
make_replay_buffer,
38+
make_sac_agent,
39+
make_sac_optimizer,
40+
make_train_environment,
41+
)
42+
43+
torch.set_float32_matmul_precision("high")
44+
45+
46+
@hydra.main(version_base="1.1", config_path="", config_name="config-async")
47+
def main(cfg: DictConfig): # noqa: F821
48+
device = cfg.network.device
49+
if device in ("", None):
50+
if torch.cuda.is_available():
51+
device = torch.device("cuda:0")
52+
else:
53+
device = torch.device("cpu")
54+
device = torch.device(device)
55+
56+
# Create logger
57+
exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
58+
logger = None
59+
if cfg.logger.backend:
60+
logger = get_logger(
61+
logger_type=cfg.logger.backend,
62+
logger_name="sac_logging",
63+
experiment_name=exp_name,
64+
wandb_kwargs={
65+
"mode": cfg.logger.mode,
66+
"config": dict(cfg),
67+
"project": cfg.logger.project_name,
68+
"group": cfg.logger.group_name,
69+
},
70+
)
71+
72+
torch.manual_seed(cfg.env.seed)
73+
np.random.seed(cfg.env.seed)
74+
75+
# Create environments
76+
_, eval_env = make_environment(cfg, logger=logger)
77+
78+
# Create agent
79+
model, exploration_policy = make_sac_agent(
80+
cfg, make_train_environment(cfg), eval_env, device
81+
)
82+
83+
# Create SAC loss
84+
loss_module, target_net_updater = make_loss_module(cfg, model)
85+
86+
compile_mode = None
87+
if cfg.compile.compile:
88+
compile_mode = cfg.compile.compile_mode
89+
if compile_mode in ("", None):
90+
if cfg.compile.cudagraphs:
91+
compile_mode = "default"
92+
else:
93+
compile_mode = "reduce-overhead"
94+
compile_mode_collector = "reduce-overhead"
95+
96+
# Create replay buffer
97+
replay_buffer = make_replay_buffer(
98+
batch_size=cfg.optim.batch_size,
99+
prb=cfg.replay_buffer.prb,
100+
buffer_size=cfg.replay_buffer.size,
101+
scratch_dir=cfg.replay_buffer.scratch_dir,
102+
device=device,
103+
shared=True,
104+
prefetch=0,
105+
)
106+
107+
# TODO: Simplify this - ideally we'd like to share the uninitialized lazy tensor storage and fetch it once
108+
# it's initialized
109+
replay_buffer.extend(make_train_environment(cfg).rollout(1).view(-1))
110+
replay_buffer.empty()
111+
112+
# Create off-policy collector
113+
collector = make_collector_async(
114+
cfg,
115+
partial(make_train_environment, cfg),
116+
exploration_policy,
117+
compile_mode=compile_mode_collector,
118+
replay_buffer=replay_buffer,
119+
)
120+
121+
# Create optimizers
122+
(
123+
optimizer_actor,
124+
optimizer_critic,
125+
optimizer_alpha,
126+
) = make_sac_optimizer(cfg, loss_module)
127+
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
128+
del optimizer_actor, optimizer_critic, optimizer_alpha
129+
130+
def update(sampled_tensordict):
131+
# Compute loss
132+
loss_td = loss_module(sampled_tensordict)
133+
134+
actor_loss = loss_td["loss_actor"]
135+
q_loss = loss_td["loss_qvalue"]
136+
alpha_loss = loss_td["loss_alpha"]
137+
138+
(actor_loss + q_loss + alpha_loss).sum().backward()
139+
optimizer.step()
140+
141+
# Update qnet_target params
142+
target_net_updater.step()
143+
144+
optimizer.zero_grad(set_to_none=True)
145+
return loss_td.detach()
146+
147+
if cfg.compile.compile:
148+
update = compile_with_warmup(update, mode=compile_mode, warmup=2)
149+
150+
if cfg.compile.cudagraphs:
151+
warnings.warn(
152+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
153+
category=UserWarning,
154+
)
155+
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)
156+
157+
# Main loop
158+
collected_frames = 0
159+
160+
init_random_frames = cfg.collector.init_random_frames
161+
assert init_random_frames == 0
162+
163+
prb = cfg.replay_buffer.prb
164+
update_freq = cfg.collector.update_freq
165+
166+
eval_rollout_steps = cfg.env.max_episode_steps
167+
# TODO: customize this
168+
num_updates = 1000
169+
total_iter = 1_000_000
170+
pbar = tqdm.tqdm(total=total_iter * num_updates)
171+
172+
while not replay_buffer.write_count:
173+
time.sleep(0.01)
174+
175+
losses = TensorDict(batch_size=[num_updates])
176+
for i in range(total_iter * num_updates):
177+
timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
178+
179+
if i % update_freq == update_freq - 1:
180+
# Update weights of the inference policy
181+
collector.update_policy_weights_()
182+
183+
pbar.update(1)
184+
185+
collected_frames = replay_buffer.write_count
186+
187+
# Optimization steps
188+
with timeit("train"):
189+
with timeit("rb - sample"):
190+
# Sample from replay buffer
191+
sampled_tensordict = replay_buffer.sample()
192+
193+
with timeit("update"):
194+
torch.compiler.cudagraph_mark_step_begin()
195+
loss_td = update(sampled_tensordict).clone()
196+
losses[i % num_updates] = loss_td.select(
197+
"loss_actor", "loss_qvalue", "loss_alpha"
198+
)
199+
200+
# Update priority
201+
if prb:
202+
replay_buffer.update_priority(sampled_tensordict)
203+
204+
# Logging
205+
if i % num_updates == num_updates - 1:
206+
metrics_to_log = {}
207+
if collected_frames >= init_random_frames:
208+
losses_m = losses.mean()
209+
metrics_to_log["train/q_loss"] = losses_m.get("loss_qvalue")
210+
metrics_to_log["train/actor_loss"] = losses_m.get("loss_actor")
211+
metrics_to_log["train/alpha_loss"] = losses_m.get("loss_alpha")
212+
metrics_to_log["train/alpha"] = loss_td["alpha"]
213+
metrics_to_log["train/entropy"] = loss_td["entropy"]
214+
metrics_to_log["train/collected_frames"] = int(replay_buffer.write_count)
215+
# Log rewards in the buffer
216+
217+
# Evaluation
218+
with set_exploration_type(
219+
ExplorationType.DETERMINISTIC
220+
), torch.no_grad(), timeit("eval"):
221+
eval_rollout = eval_env.rollout(
222+
eval_rollout_steps,
223+
model[0],
224+
auto_cast_to_device=True,
225+
break_when_any_done=True,
226+
)
227+
eval_env.apply(dump_video)
228+
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
229+
metrics_to_log["eval/reward"] = eval_reward
230+
if logger is not None:
231+
metrics_to_log.update(timeit.todict(prefix="time"))
232+
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
233+
log_metrics(logger, metrics_to_log, collected_frames)
234+
235+
collector.shutdown()
236+
if not eval_env.is_closed:
237+
eval_env.close()
238+
239+
240+
if __name__ == "__main__":
241+
main()

0 commit comments

Comments
 (0)