9 changes: 0 additions & 9 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -118,15 +118,6 @@ class DiffusionConfig:
# Inference
num_inference_steps: int | None = None

# ---
# TODO(alexander-soare): Remove these from the policy config.
use_ema: bool = True
ema_update_after_step: int = 0
ema_min_alpha: float = 0.0
ema_max_alpha: float = 0.9999
ema_inv_gamma: float = 1.0
ema_power: float = 0.75

def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
84 changes: 1 addition & 83 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -3,12 +3,8 @@
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Move EMA out of policy.
- Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
- One more pass on comments and documentation.
"""

import copy
import math
from collections import deque
from typing import Callable
@@ -21,7 +17,6 @@
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -71,13 +66,6 @@ def __init__(

self.diffusion = DiffusionModel(config)

# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.config.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = DiffusionEMA(config, model=self.ema_diffusion)

def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
@@ -109,9 +97,6 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
"horizon" may not be the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.

Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
weights.
Comment on lines -112 to -114

Collaborator:
Could you add a note about EMA, saying that we tested with and without, and got as good or better results without EMA, so we decided to remove it for the sake of simplicity?

Contributor Author:
Okay, but I added them in the yaml config as this detail is more relevant to the outer scope. PTAL.
"""
assert "observation.image" in batch
assert "observation.state" in batch
@@ -123,10 +108,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
- if not self.training and self.ema_diffusion is not None:
-     actions = self.ema_diffusion.generate_actions(batch)
- else:
-     actions = self.diffusion.generate_actions(batch)
+ actions = self.diffusion.generate_actions(batch)

# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
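The queueing scheme described in the docstring above can be illustrated with a small standalone sketch. This is not the LeRobot implementation: `generate_actions` is a stand-in for the diffusion model, the early-observation padding is simplified, and all names and shapes are hypothetical.

```python
from collections import deque

import torch

# Illustrative values only; in LeRobot these come from DiffusionConfig.
n_obs_steps, n_action_steps, horizon = 2, 8, 16
# The docstring above requires n_action_steps < horizon - n_obs_steps + 1.
assert n_action_steps < horizon - n_obs_steps + 1

obs_queue: deque = deque(maxlen=n_obs_steps)
action_queue: deque = deque(maxlen=n_action_steps)


def select_action(obs: torch.Tensor, generate_actions) -> torch.Tensor:
    """Return one action per call, generating a new chunk only when the queue is empty."""
    obs_queue.append(obs)
    while len(obs_queue) < n_obs_steps:
        obs_queue.append(obs)  # crude padding for the very first calls
    if len(action_queue) == 0:
        # Stack the n_obs_steps most recent observations: (batch, n_obs_steps, obs_dim).
        stacked = torch.stack(list(obs_queue), dim=1)
        # The model predicts a whole horizon of actions: (batch, horizon, action_dim).
        actions = generate_actions(stacked)
        # Only the first n_action_steps actions are queued and executed before re-planning.
        action_queue.extend(actions[:, :n_action_steps].transpose(0, 1))
    return action_queue.popleft()


def dummy_generate_actions(stacked_obs: torch.Tensor) -> torch.Tensor:
    return torch.zeros(stacked_obs.shape[0], horizon, 2)


action = select_action(torch.zeros(1, 4), dummy_generate_actions)  # (batch, action_dim)
```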
@@ -612,67 +594,3 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
out = self.conv2(out)
out = out + self.residual_conv(x)
return out


class DiffusionEMA:
"""
Exponential Moving Average of model weights
"""

def __init__(self, config: DiffusionConfig, model: nn.Module):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
at 10K steps, 0.9999 at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_alpha (float): The minimum EMA decay rate. Default: 0.
"""

self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)

self.update_after_step = config.ema_update_after_step
self.inv_gamma = config.ema_inv_gamma
self.power = config.ema_power
self.min_alpha = config.ema_min_alpha
self.max_alpha = config.ema_max_alpha

self.alpha = 0.0
self.optimization_step = 0

def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power

if step <= 0:
return 0.0

return max(self.min_alpha, min(value, self.max_alpha))

@torch.no_grad()
def step(self, new_model):
self.alpha = self.get_decay(self.optimization_step)

for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
# Iterate over immediate parameters only.
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
):
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, _BatchNorm) or not param.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.alpha)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)

self.optimization_step += 1
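For reference, the warmup schedule in the removed `get_decay` can be checked with a small standalone sketch (the function name `ema_decay` is mine, not part of the codebase); it reproduces the figures quoted in the docstring above.

```python
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3,
              min_alpha: float = 0.0, max_alpha: float = 0.9999) -> float:
    """Same formula as get_decay above, applied to the effective warmup step."""
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_alpha, min(value, max_alpha))


# gamma=1, power=2/3: decay ~0.999 around 31.6K steps, ~0.9999 around 1M steps.
print(f"{ema_decay(31_600):.4f}")               # 0.9990
print(f"{ema_decay(1_000_000):.5f}")            # 0.99990
# gamma=1, power=3/4: decay ~0.999 around 10K steps, ~0.9999 around 215.4K steps.
print(f"{ema_decay(10_000, power=0.75):.4f}")   # 0.9990
print(f"{ema_decay(215_400, power=0.75):.5f}")  # 0.99990
```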
13 changes: 4 additions & 9 deletions lerobot/configs/policy/diffusion.yaml
@@ -1,5 +1,9 @@
# @package _global_

# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.

seed: 100000
dataset_repo_id: lerobot/pusht

@@ -91,12 +95,3 @@ policy:

# Inference
num_inference_steps: 100

# ---
# TODO(alexander-soare): Remove these from the policy config.
use_ema: true
ema_update_after_step: 0
ema_min_alpha: 0.0
ema_max_alpha: 0.9999
ema_inv_gamma: 1.0
ema_power: 0.75
2 changes: 1 addition & 1 deletion lerobot/scripts/eval.py
@@ -121,7 +121,7 @@ def rollout(
max_steps = env.call("_max_episode_steps")[0]
progbar = trange(
max_steps,
- desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
+ desc=f"Running rollout with at most {max_steps} steps",
disable=not enable_progbar,
leave=False,
)
3 changes: 0 additions & 3 deletions lerobot/scripts/train.py
@@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
if lr_scheduler is not None:
lr_scheduler.step()

if hasattr(policy, "ema") and policy.ema is not None:
policy.ema.step(policy.diffusion)

if isinstance(policy, PolicyWithUpdate):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
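To situate this change: with the explicit EMA step gone, a training step in `update_policy` reduces to the usual optimize-then-update pattern, with per-policy buffers (such as TDMPC's EMA) handled through `PolicyWithUpdate`. The sketch below is a hypothetical simplification, not the actual function (loss computation, AMP, and logging are omitted).

```python
import torch


def update_policy_sketch(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    # Assumed interface: policy.forward(batch) returns a dict containing a scalar "loss".
    loss = policy.forward(batch)["loss"]
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()
    # No more policy.ema.step(...): policies that maintain internal buffers
    # (e.g. an exponential moving average like in TDMPC) update them here.
    if hasattr(policy, "update"):  # the real code checks isinstance(policy, PolicyWithUpdate)
        policy.update()
    return loss.item()
```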
4 changes: 1 addition & 3 deletions poetry.lock

Some generated files are not rendered by default.

Binary file not shown.
Binary file not shown.
Binary file not shown.
12 changes: 3 additions & 9 deletions tests/scripts/save_policy_to_safetensor.py
@@ -88,14 +88,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override


if __name__ == "__main__":
env_policies = [
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
(
"pusht",
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
]
# Instructions: include the policies that you want to save artifacts for here. Please make sure to revert
# your changes when you are done.
env_policies = []
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
11 changes: 11 additions & 0 deletions tests/test_policies.py
@@ -249,6 +249,17 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifacts should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
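As a concrete illustration of steps 2 and 3 in the docstring above, an entry in `env_policies` has the same (env, policy, overrides) shape as the tuples that were removed from the `__main__` block; the values below are only an example of what you might temporarily add.

```python
# Temporarily re-populate env_policies in tests/scripts/save_policy_to_safetensors.py,
# regenerate the artifacts, then revert this change.
env_policies = [
    ("pusht", "diffusion", ["policy.n_action_steps=8", "policy.num_inference_steps=10"]),
]
```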