Skip to content

Commit 67f1b50

Browse files
committed
rename make_optimizer
1 parent 6ea94ec commit 67f1b50

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

lerobot/scripts/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from lerobot.scripts.eval import eval_policy
2626

2727

28-
def make_optimizer(cfg, policy):
28+
def make_optimizer_and_scheduler(cfg, policy):
2929
if cfg.policy.name == "act":
3030
optimizer_params_dicts = [
3131
{
@@ -321,7 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
321321

322322
# Create optimizer and scheduler
323323
# Temporary hack to move optimizer out of policy
324-
optimizer, lr_scheduler = make_optimizer(cfg, policy)
324+
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
325325

326326
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
327327
num_total_params = sum(p.numel() for p in policy.parameters())

tests/scripts/save_policy_to_safetensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from lerobot.common.datasets.factory import make_dataset
88
from lerobot.common.policies.factory import make_policy
99
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
10-
from lerobot.scripts.train import make_optimizer
10+
from lerobot.scripts.train import make_optimizer_and_scheduler
1111
from tests.utils import DEFAULT_CONFIG_PATH
1212

1313

@@ -25,7 +25,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
2525
dataset = make_dataset(cfg)
2626
policy = make_policy(cfg, dataset_stats=dataset.stats)
2727
policy.train()
28-
optimizer, _ = make_optimizer(cfg, policy)
28+
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
2929

3030
dataloader = torch.utils.data.DataLoader(
3131
dataset,

0 commit comments

Comments (0)