feat(dataset): Add Multidataset Training support #2008
base: main
Conversation
Pull Request Overview
This PR adds multi-dataset training support to the LeRobot codebase, enabling models to be trained across multiple heterogeneous datasets. The implementation supports feature mapping, automatic padding of observations and actions to common dimensions, and unified handling of image features across datasets.
Key changes:
- Added comprehensive multi-dataset functionality with feature mapping and padding capabilities
- Extended stats aggregation to handle multiple robot types with different dimensional requirements
- Created example script demonstrating multi-dataset training setup
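To make the intended workflow concrete, here is a hedged usage sketch based only on what the diff shows (the `MultiLeRobotDataset` class in `src/lerobot/datasets/lerobot_dataset.py` and the `*_padding_mask` keys); the repo ids and constructor arguments are illustrative assumptions, not necessarily the PR's exact API:

```python
from torch.utils.data import DataLoader

from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset

# Hypothetical repo ids mixing datasets with different action/state dimensions.
repo_ids = ["lerobot/aloha_sim_insertion_human", "lerobot/pusht"]

# Per the PR description, features are mapped to shared keys and action/state
# vectors are zero-padded to the largest dimension found across the datasets.
dataset = MultiLeRobotDataset(repo_ids)

loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))

# Boolean masks mark which trailing dimensions were zero-padded (True = padded),
# so a policy can ignore padded action/state entries during training.
action_mask = batch.get("actions_padding_mask")
state_mask = batch.get("obs_state_padding_mask")
```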
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| src/lerobot/datasets/lerobot_dataset.py | Major addition of MultiLeRobotDataset classes with feature mapping, padding, and unified dataset handling |
| src/lerobot/datasets/compute_stats.py | Extended stats aggregation functionality for multi-dataset scenarios |
| examples/tester.sh | Shell script for setting up environment variables for testing |
| examples/tester.py | Example implementation demonstrating multi-dataset usage |
```python
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError

from collections import defaultdict
```
Duplicate import of `defaultdict`. This import is already present at line 1439 and should be removed from line 34 to avoid redundancy.
Suggested change:
```diff
-from collections import defaultdict
```
```python
from collections import defaultdict
from typing import Callable
import copy
import numpy as np
import torch
import datasets
from pathlib import Path
```
These imports should be moved to the top of the file with the other imports rather than being placed in the middle of the code. This violates Python import conventions and makes the code harder to maintain.
Suggested change:
```diff
-from collections import defaultdict
-from typing import Callable
-import copy
-import numpy as np
-import torch
-import datasets
-from pathlib import Path
```
```python
try:
    from lerobot.common.constants import (
        ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
    )
except Exception:
    # Fallbacks if constants are already strings elsewhere
    ACTION = "action"
    OBS_ENV_STATE = "observation.env_state"
    OBS_STATE = "observation.state"
    OBS_IMAGE = "observation.image"
    OBS_IMAGE_2 = "observation.image_2"
    OBS_IMAGE_3 = "observation.image_3"
```
Catching the overly broad `except Exception:` could mask errors unrelated to the import. Consider catching `ImportError` instead (this also covers `ModuleNotFoundError`, which is a subclass of it).
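A minimal sketch of the narrower handler, shown with just two of the constants for brevity:

```python
try:
    from lerobot.common.constants import ACTION, OBS_STATE
except ImportError:  # ModuleNotFoundError is a subclass, so both are covered
    # Fall back to plain string constants when the module is unavailable.
    ACTION = "action"
    OBS_STATE = "observation.state"
```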
```diff
         episodes: dict | None = None,
         image_transforms: Callable | None = None,
-        delta_timestamps: dict[str, list[float]] | None = None,
+        delta_timestamps: dict[list[float]] | None = None,
```
The type annotation is incorrect: `dict[list[float]]` is not a valid parameterization of `dict` (it takes separate key and value types). It should be `dict[str, list[float]] | None = None` to match the original signature and the usage throughout the code, where `delta_timestamps` is expected to have string keys.
Suggested change:
```diff
-        delta_timestamps: dict[list[float]] | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
```
```python
        # with multiple robots of different ranges. Instead we should have one normalization
        # per robot.
        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
        self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None
```
This assignment is problematic because `repo_id` is the loop variable from the dataset-creation loop, so after the loop it always references the last `repo_id`. The lookup should either be moved inside the loop or restructured to avoid the incorrect assignment.
Suggested change:
```diff
-        self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None
+        self.delta_timestamps = {rid: delta_timestamps.get(rid, None) for rid in datasets_repo_ids} if delta_timestamps else None
```
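For the first option (resolving per-dataset timestamps inside the creation loop), a sketch of what that could look like, assuming the loop builds the underlying datasets from `repo_ids`; the actual loop in the PR may differ:

```python
self._datasets = [
    LeRobotDataset(
        repo_id,
        # Look up this dataset's timestamps while repo_id is still in scope.
        delta_timestamps=delta_timestamps.get(repo_id) if delta_timestamps else None,
    )
    for repo_id in repo_ids
]
```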
if "actions" in item and self.max_action_dim is not None: | ||
act = item["actions"] | ||
if act.shape[-1] < self.max_action_dim: | ||
pad_len = self.max_action_dim - act.shape[-1] | ||
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) | ||
item["actions_padding_mask"] = torch.cat( | ||
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | ||
dim=-1, | ||
) | ||
|
||
# pad obs_state if too short | ||
if "obs_state" in item and self.max_state_dim is not None: | ||
st = item["obs_state"] | ||
if st.shape[-1] < self.max_state_dim: | ||
pad_len = self.max_state_dim - st.shape[-1] | ||
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) | ||
item["obs_state_padding_mask"] = torch.cat( | ||
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | ||
dim=-1, | ||
) | ||
# actions | ||
if "actions" in item and self.max_action_dim is not None: | ||
act = item["actions"] | ||
if act.shape[-1] < self.max_action_dim: | ||
pad_len = self.max_action_dim - act.shape[-1] | ||
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) | ||
mask = torch.cat( | ||
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | ||
dim=-1, | ||
) | ||
else: | ||
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding | ||
item["actions_padding_mask"] = mask | ||
# obs state | ||
if "obs_state" in item and self.max_state_dim is not None: | ||
st = item["obs_state"] | ||
if st.shape[-1] < self.max_state_dim: | ||
pad_len = self.max_state_dim - st.shape[-1] | ||
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) | ||
mask = torch.cat( | ||
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | ||
dim=-1, | ||
) | ||
else: | ||
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask | ||
item["obs_state_padding_mask"] = mask | ||
|
Duplicate padding logic for both 'actions' and 'obs_state': the padding code is repeated twice for each feature type, which creates a maintenance burden. Consider extracting it into a helper function or removing the duplicate blocks.
if "actions" in item and self.max_action_dim is not None: | |
act = item["actions"] | |
if act.shape[-1] < self.max_action_dim: | |
pad_len = self.max_action_dim - act.shape[-1] | |
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) | |
item["actions_padding_mask"] = torch.cat( | |
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | |
dim=-1, | |
) | |
# pad obs_state if too short | |
if "obs_state" in item and self.max_state_dim is not None: | |
st = item["obs_state"] | |
if st.shape[-1] < self.max_state_dim: | |
pad_len = self.max_state_dim - st.shape[-1] | |
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) | |
item["obs_state_padding_mask"] = torch.cat( | |
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | |
dim=-1, | |
) | |
# actions | |
if "actions" in item and self.max_action_dim is not None: | |
act = item["actions"] | |
if act.shape[-1] < self.max_action_dim: | |
pad_len = self.max_action_dim - act.shape[-1] | |
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) | |
mask = torch.cat( | |
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | |
dim=-1, | |
) | |
else: | |
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding | |
item["actions_padding_mask"] = mask | |
# obs state | |
if "obs_state" in item and self.max_state_dim is not None: | |
st = item["obs_state"] | |
if st.shape[-1] < self.max_state_dim: | |
pad_len = self.max_state_dim - st.shape[-1] | |
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) | |
mask = torch.cat( | |
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], | |
dim=-1, | |
) | |
else: | |
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask | |
item["obs_state_padding_mask"] = mask | |
# Pad actions and obs_state features if needed | |
self._pad_feature(item, "actions", self.max_action_dim) | |
self._pad_feature(item, "obs_state", self.max_state_dim) |
```python
    datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
    filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
    print(
        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
    )
    return filtered_datasets, datasets_maks
```
Variable name has a typo: 'datasets_maks' should be 'datasets_mask'.
Suggested change:
```diff
-    datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
-    filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
-    print(
-        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
-    )
-    return filtered_datasets, datasets_maks
+    datasets_mask = [ds not in datasets_to_remove for ds in ls_datasets]
+    filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
+    print(
+        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
+    )
+    return filtered_datasets, datasets_mask
```
```python
    def _should_ignore(self, key: str) -> bool:
        # exact or glob-style match
        for pat in self._ignore_patterns:
            if key == pat or fnmatch.fnmatch(key, pat):
```
Missing import for the `fnmatch` module. `fnmatch.fnmatch` is used here but `fnmatch` is never imported, which will cause a `NameError` at runtime.
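The fix is a one-line addition to the module's import block. For reference, `fnmatch.fnmatch` performs shell-style glob matching, which is what `_should_ignore` relies on:

```python
import fnmatch

# "observation.image_2" matches the glob pattern "observation.image*".
assert fnmatch.fnmatch("observation.image_2", "observation.image*")
```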
```python
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
    """Reshape features to have a maximum dimension of `max_dim`."""
```
Using the mutable default argument `{}` for the `keys_to_max_dim` parameter is a common Python anti-pattern: the default dict is created once at function definition time and shared across all calls, so mutations leak between calls. Use `None` as the default and create a new dict inside the function if needed.
Suggested change:
```diff
-def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
-    """Reshape features to have a maximum dimension of `max_dim`."""
+def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = None) -> dict:
+    """Reshape features to have a maximum dimension of `max_dim`."""
+    if keys_to_max_dim is None:
+        keys_to_max_dim = {}
```
What this does
Add Multidataset Training support
TBD