4 changes: 1 addition & 3 deletions README.md
@@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:

- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).

To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):

```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
└── model.pt
```

With the folder prepared, run the following with a desired revision ID.
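
As a hedged sketch (not necessarily the README's own command), pushing such a folder with a chosen revision ID via the `huggingface_hub` Python API might look like the following; the repo ID and revision below are placeholders:

```
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/your-policy"  # placeholder
revision = "v1.0"  # placeholder revision ID

# Create the repo and the branch for the revision, then upload the prepared folder.
api.create_repo(repo_id, exist_ok=True)
api.create_branch(repo_id, branch=revision)
api.upload_folder(folder_path="to_upload", repo_id=repo_id, revision=revision)
```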
2 changes: 1 addition & 1 deletion examples/1_load_hugging_face_dataset.py
@@ -44,7 +44,7 @@
# TODO(rcadene): list available datasets on lerobot page using `datasets`

# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10

# display name of dataset and its features
# TODO(rcadene): update to make the print pretty
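
Pinning `revision="v1.1"` keeps the example tied to a fixed dataset version. A hedged sketch of inspecting what comes back, using standard `datasets` attributes rather than anything specific to this repository:

```
from datasets import load_dataset

hf_dataset = load_dataset("lerobot/pusht", split="train", revision="v1.1")

print(hf_dataset)            # dataset summary: split name and number of rows
print(hf_dataset.features)   # column names and types
print(hf_dataset[0].keys())  # keys available in a single frame
```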
2 changes: 0 additions & 2 deletions examples/3_evaluate_pretrained_policy.py
@@ -19,7 +19,6 @@

config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats

# Override some config parameters to do with evaluation.
overrides = [
@@ -36,5 +35,4 @@
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)
5 changes: 2 additions & 3 deletions examples/4_train_policy.py
@@ -34,7 +34,7 @@
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train()
policy.to(device)

@@ -62,7 +62,6 @@
done = True
break

# Save the policy, configuration, and normalization stats for later use.
# Save the policy and configuration for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
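
Passing `dataset_stats=dataset.stats` at construction time moves normalization into the policy itself, which is why `stats.pth` no longer needs to be saved separately. A hedged sketch of the structure `dataset.stats` is expected to have (the min/max values are borrowed from the PushT numbers removed from the dataset factory below; the mean/std values are placeholders, not real dataset statistics):

```
import torch

# Hedged sketch: a nested dict keyed by modality, then by statistic name.
example_stats = {
    "observation.state": {
        "min": torch.tensor([13.456424, 32.938293]),
        "max": torch.tensor([496.14618, 510.9579]),
        "mean": torch.tensor([250.0, 250.0]),  # placeholder
        "std": torch.tensor([100.0, 100.0]),   # placeholder
    },
    "action": {
        "min": torch.tensor([12.0, 25.0]),
        "max": torch.tensor([511.0, 511.0]),
        "mean": torch.tensor([250.0, 250.0]),  # placeholder
        "std": torch.tensor([100.0, 100.0]),   # placeholder
    },
}

# A policy built with such stats, e.g.
# DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=example_stats),
# can normalize its inputs and unnormalize its outputs without an external stats file.
```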
57 changes: 10 additions & 47 deletions lerobot/common/datasets/factory.py
@@ -2,18 +2,13 @@
from pathlib import Path

import torch
from torchvision.transforms import v2

from lerobot.common.transforms import NormalizeTransform
from omegaconf import OmegaConf

DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.env.name == "xarm":
@@ -33,58 +28,26 @@ def make_dataset(
else:
raise ValueError(cfg.env.name)

transforms = None
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)

transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)

delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])

# TODO(rcadene): add data augmentations

dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)

if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

return dataset
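
The hard-coded PushT min/max block removed above is replaced by the generic `override_dataset_stats` hook at the end of `make_dataset`. A hedged sketch of exercising that hook directly (the config fragment is an assumption about the Hydra config layout; the loop mirrors the code above):

```
import torch
from omegaconf import OmegaConf

# Hypothetical config fragment, mirroring the PushT min/max values removed above.
cfg = OmegaConf.create(
    {
        "override_dataset_stats": {
            "observation.state": {"min": [13.456424, 32.938293], "max": [496.14618, 510.9579]},
            "action": {"min": [12.0, 25.0], "max": [511.0, 511.0]},
        }
    }
)

# Stand-in for a dataset's stats dict; in make_dataset() this is `dataset.stats`.
dataset_stats = {"observation.state": {}, "action": {}}

if cfg.get("override_dataset_stats"):
    for key, stats_dict in cfg.override_dataset_stats.items():
        for stats_type, listconfig in stats_dict.items():
            # example of stats_type: min, max, mean, std
            stats = OmegaConf.to_container(listconfig, resolve=True)
            dataset_stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
```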
18 changes: 4 additions & 14 deletions lerobot/common/envs/utils.py
@@ -1,10 +1,8 @@
import einops
import torch

from lerobot.common.transforms import apply_inverse_transform


def preprocess_observation(observation, transform=None):
def preprocess_observation(observation):
# map to expected inputs for the policy
obs = {}

@@ -24,7 +22,7 @@ def preprocess_observation(observation, transform=None):
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255

@@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()

# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])

return obs


def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
def postprocess_action(action):
action = action.to("cpu").numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
68 changes: 46 additions & 22 deletions lerobot/common/policies/act/configuration_act.py
@@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig:
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `camera_names`.
Those are: `input_shapes` and `output_shapes`.

Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary in which the key represents the modality (e.g. "observation.state")
and the value specifies the normalization mode to apply. The two available modes are "mean_std",
which subtracts the mean and divides by the standard deviation, and "min_max", which rescales to
a [-1, 1] range.
unnormalize_output_modes: A dictionary similar to `normalize_input_modes`, but used to unnormalize back to the original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -50,21 +57,35 @@
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""

# Environment.
state_dim: int = 14
action_dim: int = 14

# Inputs / output structure.
# Input / output structure.
n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100

# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}
)

# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])

# Architecture.
# Vision backbone.
@@ -117,7 +138,10 @@ def __post_init__(self):
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
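
A hedged sketch of constructing the config with the new shape and normalization fields; the values simply mirror the defaults shown above, and this block is illustrative rather than part of the diff:

```
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Single "top" camera setup; shapes exclude batch and temporal dimensions.
cfg = ActionChunkingTransformerConfig(
    input_shapes={
        "observation.images.top": [3, 480, 640],  # channels, height, width
        "observation.state": [14],
    },
    output_shapes={"action": [14]},
    normalize_input_modes={
        "observation.image": "mean_std",
        "observation.state": "mean_std",
    },
    unnormalize_output_modes={"action": "mean_std"},
    chunk_size=100,
    n_action_steps=100,  # must be no greater than chunk_size
)
```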