
Commit c77633c

Add regression tests (#119)

- Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts
- Add `test_backward_compatibility` to test generated outputs from the policies against the artifacts

1 parent 19812ca commit c77633c
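The pattern behind these tests is to pin reference policy outputs to disk once, then assert on every test run that freshly computed outputs still match. A minimal sketch of that idea using `safetensors` (the file name, helper names, and tolerance here are illustrative, not the PR's actual test code):

import torch
from safetensors.torch import load_file, save_file

def save_artifact(outputs: dict[str, torch.Tensor], path: str = "artifact.safetensors"):
    # Run once to pin down the current behavior as the reference artifact.
    save_file(outputs, path)

def check_backward_compatibility(outputs: dict[str, torch.Tensor], path: str = "artifact.safetensors"):
    # On each test run, compare fresh outputs against the saved reference.
    reference = load_file(path)
    for key, tensor in outputs.items():
        assert torch.allclose(tensor, reference[key], atol=1e-7), key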

File tree

15 files changed: +236 -43 lines changed

lerobot/common/policies/tdmpc/modeling_tdmpc.py

Lines changed: 2 additions & 1 deletion

@@ -80,7 +80,8 @@ def __init__(
         self.config = config
         self.model = TDMPCTOLD(config)
         self.model_target = deepcopy(self.model)
-        self.model_target.eval()
+        for param in self.model_target.parameters():
+            param.requires_grad = False

         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
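This hunk freezes the TD-MPC target network by disabling gradients on its parameters instead of calling `eval()`, which only toggles layer modes such as dropout and batch norm and does not stop gradients from flowing. A self-contained sketch of the same pattern (the linear layer stands in for `TDMPCTOLD` and is purely illustrative):

from copy import deepcopy

import torch
import torch.nn as nn

# Hypothetical online network standing in for TDMPCTOLD.
q_net = nn.Linear(4, 2)

# Target network: a frozen copy of the online network.
q_target = deepcopy(q_net)
for param in q_target.parameters():
    param.requires_grad = False  # no gradients flow into the target

out = q_target(torch.randn(1, 4))
print(out.requires_grad)  # False: target outputs carry no grad history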

lerobot/configs/policy/tdmpc.yaml

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 # @package _global_

 seed: 1
+dataset_repo_id: lerobot/xarm_lift_medium_replay

 training:
   offline_steps: 25000
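The new `dataset_repo_id` key resolves like any other top-level value in the Hydra/OmegaConf config tree. A minimal OmegaConf sketch showing how it is read (the inline YAML string is an illustrative subset, not the full config):

from omegaconf import OmegaConf

# Illustrative subset of lerobot/configs/policy/tdmpc.yaml.
cfg = OmegaConf.create(
    """
    seed: 1
    dataset_repo_id: lerobot/xarm_lift_medium_replay
    training:
      offline_steps: 25000
    """
)
print(cfg.dataset_repo_id)  # lerobot/xarm_lift_medium_replay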

lerobot/scripts/train.py

Lines changed: 46 additions & 40 deletions

@@ -25,6 +25,51 @@
 from lerobot.scripts.eval import eval_policy


+def make_optimizer_and_scheduler(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()

@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
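The refactor moves the per-policy optimizer wiring out of `train()` and behind a single `make_optimizer_and_scheduler` call. In the ACT branch, two parameter groups let the vision backbone train at its own learning rate while everything else uses the default. A self-contained sketch of that grouping pattern (the toy model and learning rates are illustrative, not lerobot's actual values):

import torch
import torch.nn as nn

# Toy policy: a "backbone" plus a head, mirroring the ACT param grouping.
policy = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 2)})

param_groups = [
    {  # default group: every trainable parameter outside the backbone
        "params": [
            p for n, p in policy.named_parameters()
            if not n.startswith("backbone") and p.requires_grad
        ]
    },
    {  # backbone group: overrides the default learning rate
        "params": [
            p for n, p in policy.named_parameters()
            if n.startswith("backbone") and p.requires_grad
        ],
        "lr": 1e-5,  # illustrative backbone learning rate
    },
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)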
7 binary files changed (test artifacts; contents not shown): 4.98 KB, 30.9 KB, 196 Bytes, 32.6 KB, 4.49 KB, 46.3 KB, 68 Bytes.
