
Commit b09c57f

Author: Vincent Moens (committed)
[Algorithm] Async SAC
ghstack-source-id: 0603426 Pull-Request-resolved: #2946
1 parent d14c227 commit b09c57f

File tree

7 files changed: 518 additions, 56 deletions

sota-implementations/sac/config-async.yaml

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# environment and task
env:
  name: HalfCheetah-v4
  task: ""
  library: gymnasium
  max_episode_steps: 1000
  seed: 42

# collector
collector:
  total_frames: -1
  init_random_frames: 25000
  frames_per_batch: 8000
  init_env_steps: 1000
  device: cuda:1
  env_per_collector: 8
  reset_at_each_iter: False
  update_freq: 10_000

# replay buffer
replay_buffer:
  size: 1000000
  prb: 0  # use prioritized experience replay
  scratch_dir:

# optim
optim:
  utd_ratio: 1.0
  gamma: 0.99
  loss_function: l2
  lr: 3.0e-4
  weight_decay: 0.0
  batch_size: 256
  target_update_polyak: 0.995
  alpha_init: 1.0
  adam_eps: 1.0e-8

# network
network:
  hidden_sizes: [256, 256]
  activation: relu
  default_policy_scale: 1.0
  scale_lb: 0.1
  device:

# logging
logger:
  backend: wandb
  project_name: torchrl_example_sac
  group_name: null
  exp_name: ${env.name}_SAC
  mode: online
  log_freq: 25000  # logging freq in updates
  video: False

compile:
  compile: False
  compile_mode:
  cudagraphs: False
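
The training script below consumes this file through Hydra (config_name="config-async"). As an illustration only, not part of the commit, a minimal OmegaConf sketch (assuming the file is saved as config-async.yaml and omegaconf, a Hydra dependency, is installed) shows how the keys and the ${env.name} interpolation resolve:

    # Illustration only: load the config with OmegaConf and resolve the
    # ${env.name} interpolation used by logger.exp_name.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config-async.yaml")
    print(cfg.env.name)          # HalfCheetah-v4
    print(cfg.logger.exp_name)   # HalfCheetah-v4_SAC (interpolation resolved on access)
    print(cfg.optim.batch_size)  # 256

Hydra also lets individual keys be overridden from the command line with key=value pairs (for example optim.batch_size=512).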

sota-implementations/sac/sac-async.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""SAC Example.

This is a simple self-contained example of a SAC training script.

It supports state environments like MuJoCo.

The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import time

import warnings
from functools import partial

import hydra
import numpy as np
import tensordict
import torch
import torch.cuda
import tqdm
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_collector_async,
    make_environment,
    make_loss_module,
    make_replay_buffer,
    make_sac_agent,
    make_sac_optimizer,
    make_train_environment,
)

torch.set_float32_matmul_precision("high")
tensordict.nn.functional_modules._exclude_td_from_pytree().set()


@hydra.main(version_base="1.1", config_path="", config_name="config-async")
def main(cfg: DictConfig):  # noqa: F821
    device = cfg.network.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    device = torch.device(device)

    # Create logger
    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="sac_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    _, eval_env = make_environment(cfg, logger=logger)

    # Create agent
    model, _ = make_sac_agent(
        cfg, make_train_environment(cfg), eval_env, device
    )
    _, exploration_policy = make_sac_agent(
        cfg, make_train_environment(cfg), eval_env, "cuda:1"
    )
    exploration_policy.load_state_dict(model[0].state_dict())

    # Create SAC loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

    compile_mode = None
    if cfg.compile.compile:
        compile_mode = cfg.compile.compile_mode
        if compile_mode in ("", None):
            if cfg.compile.cudagraphs:
                compile_mode = "default"
            else:
                compile_mode = "reduce-overhead"
    compile_mode_collector = compile_mode  # "reduce-overhead"

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
        device=device,
        shared=True,
        prefetch=0,
    )

    # TODO: Simplify this - ideally we'd like to share the uninitialized lazy tensor storage and fetch it once
    # it's initialized
    replay_buffer.extend(make_train_environment(cfg).rollout(1).view(-1))
    replay_buffer.empty()

    # Create off-policy collector
    collector = make_collector_async(
        cfg,
        partial(make_train_environment, cfg),
        exploration_policy,
        compile_mode=compile_mode_collector,
        replay_buffer=replay_buffer,
    )

    # Create optimizers
    (
        optimizer_actor,
        optimizer_critic,
        optimizer_alpha,
    ) = make_sac_optimizer(cfg, loss_module)
    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
    del optimizer_actor, optimizer_critic, optimizer_alpha

    def update(sampled_tensordict):
        # Compute loss
        loss_td = loss_module(sampled_tensordict)

        actor_loss = loss_td["loss_actor"]
        q_loss = loss_td["loss_qvalue"]
        alpha_loss = loss_td["loss_alpha"]

        (actor_loss + q_loss + alpha_loss).sum().backward()
        optimizer.step()

        # Update qnet_target params
        target_net_updater.step()

        optimizer.zero_grad(set_to_none=True)
        return loss_td.detach()

    if cfg.compile.compile:
        update = compile_with_warmup(update, mode=compile_mode, warmup=2)

    cudagraphs = cfg.compile.cudagraphs
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)

    # Main loop
    collected_frames = 0

    init_random_frames = cfg.collector.init_random_frames

    prb = cfg.replay_buffer.prb
    update_freq = cfg.collector.update_freq

    eval_rollout_steps = cfg.env.max_episode_steps
    log_freq = cfg.logger.log_freq
    # TODO: customize this
    num_updates = 1000
    total_iter = 1_000_000
    pbar = tqdm.tqdm(total=total_iter * num_updates)

    while replay_buffer.write_count <= init_random_frames:
        time.sleep(0.01)

    losses = []
    for i in range(total_iter * num_updates):
        timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)

        if i % update_freq == update_freq - 1:
            # Update weights of the inference policy
            collector.update_policy_weights_()

        pbar.update(1)

        collected_frames = replay_buffer.write_count

        # Optimization steps
        with timeit("train"):
            with timeit("rb - sample"):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()

            with timeit("update"):
                torch.compiler.cudagraph_mark_step_begin()
                # After a certain number of warmup steps, CudaGraphModule will register the graph
                # We want to pause the collector while this is happening
                if cudagraphs and update.counter == (update._warmup - 1):
                    with collector.pause():
                        loss_td = update(sampled_tensordict).clone()
                else:
                    loss_td = update(sampled_tensordict).clone()
            losses.append(loss_td.select("loss_actor", "loss_qvalue", "loss_alpha"))

            # Update priority
            if prb:
                replay_buffer.update_priority(sampled_tensordict)

        # Logging
        if (i % log_freq) == (log_freq - 1):
            metrics_to_log = {}
            if collected_frames >= init_random_frames:
                losses_m = torch.stack(losses).mean()
                losses = []
                metrics_to_log["train/q_loss"] = losses_m.get("loss_qvalue")
                metrics_to_log["train/actor_loss"] = losses_m.get("loss_actor")
                metrics_to_log["train/alpha_loss"] = losses_m.get("loss_alpha")
                metrics_to_log["train/alpha"] = loss_td["alpha"]
                metrics_to_log["train/entropy"] = loss_td["entropy"]
                metrics_to_log["train/collected_frames"] = int(
                    replay_buffer.write_count
                )
                # Log rewards in the buffer

            # Evaluation
            with set_exploration_type(
                ExplorationType.DETERMINISTIC
            ), torch.no_grad(), timeit("eval"):
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_env.apply(dump_video)
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
            if logger is not None:
                metrics_to_log.update(timeit.todict(prefix="time"))
                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
                log_metrics(logger, metrics_to_log, collected_frames)

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()


if __name__ == "__main__":
    main()
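
In this script the collector is handed the shared replay buffer and fills it in the background, while the training loop only samples and updates; the details live in make_collector_async inside the accompanying utils.py, which is not part of this diff. As a rough, library-free illustration of that producer/consumer pattern (a toy sketch with plain Python threading, not the TorchRL API; all names below are illustrative):

    # Toy sketch of the producer/consumer pattern only; not the TorchRL API.
    import random
    import threading
    import time

    buffer = []                    # stands in for the shared replay buffer
    lock = threading.Lock()
    stop = threading.Event()


    def collect():
        # Producer: keeps appending "frames" until asked to stop,
        # like the async collector filling the replay buffer.
        while not stop.is_set():
            with lock:
                buffer.append(random.random())
            time.sleep(0.001)


    thread = threading.Thread(target=collect, daemon=True)
    thread.start()

    # Consumer: wait for a warm-up amount of data (cf. init_random_frames) ...
    while len(buffer) < 100:
        time.sleep(0.01)

    # ... then run updates independently of collection (cf. replay_buffer.sample()).
    for _ in range(10):
        with lock:
            batch = random.sample(buffer, 8)
        # a gradient step on `batch` would go here

    stop.set()
    thread.join()                  # cf. collector.shutdown()

In the script itself the warm-up wait is the `while replay_buffer.write_count <= init_random_frames` loop, and collector.update_policy_weights_() is called every update_freq iterations so the inference policy keeps up with the trained weights.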
