@@ -25,6 +25,51 @@
 from lerobot.scripts.eval import eval_policy
 
 
+def make_optimizer_and_scheduler(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
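
For reviewers who want to exercise the new helper on its own, here is a minimal sketch, not part of this change. It assumes the diff targets lerobot/scripts/train.py (hence the import path), stands in a toy nn.Module for a real ACT policy, and fills an OmegaConf config with placeholder values for just the fields the "act" branch reads.

# Standalone sketch; the import path below is an assumption based on the file shown in this diff.
import torch
from omegaconf import OmegaConf

from lerobot.scripts.train import make_optimizer_and_scheduler

# Placeholder config exposing only what the "act" branch accesses.
cfg = OmegaConf.create(
    {
        "policy": {"name": "act"},
        "training": {"lr": 1e-4, "lr_backbone": 1e-5, "weight_decay": 1e-4},
    }
)

class ToyActPolicy(torch.nn.Module):
    """Toy stand-in for an ACT policy; only the parameter-name prefixes matter here."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)  # matched by n.startswith("backbone")
        self.head = torch.nn.Linear(8, 2)      # falls into the default-lr group

optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, ToyActPolicy())
assert lr_scheduler is None                                         # the "act" branch returns no scheduler
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone  # backbone group gets its own learning rate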