66 changes: 66 additions & 0 deletions examples/tester.py
@@ -0,0 +1,66 @@
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, Resize, ToImage

from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset

REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id

feature_keys_mapping = {
REPO_A: { # pusht (1 camera, 2-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.image": "obs_image.cam_high",
},
REPO_B: { # dual arm (3 cameras, 14-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.images.cam_high": "obs_image.cam_high",
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
},
}

image_tf = Compose([
ToImage(), # converts to tensor if needed
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
])
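
# quick sanity check: ToImage/Resize are standard torchvision v2 ops, so a
# pusht-sized dummy frame should come out as 3x224x224
import torch
assert image_tf(torch.zeros(3, 96, 96)).shape[-2:] == (224, 224)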

dataset = MultiLeRobotDataset(
repo_ids=[REPO_A, REPO_B],
image_transforms=image_tf, # ensures same HxW
feature_keys_mapping=feature_keys_mapping,
train_on_all_features=True, # keep union of cameras; zero-fill missing
    # optional overrides; if omitted, the maxima are inferred from the datasets:
    max_action_dim=14,
    max_state_dim=14,
max_image_dim=224,
ignore_keys=[
"next.*", # drop reward/done/success
"index",
"timestamp",
"videos/*", # drop all video metadata
"observation.effort", # 👈 drop effort everywhere
],
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

# draw batches through a single iterator (calling iter(loader) on every step
# would re-fetch the first batch forever)
it = iter(loader)
for _ in range(100):
    batch = next(it)

# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
assert batch["actions"].shape[-1] == 14
assert batch["obs_state"].shape[-1] == 14
assert batch["actions_padding_mask"].shape[-1] == 14
assert batch["obs_state_padding_mask"].shape[-1] == 14

# cameras: all canonical keys exist; pusht will have wrists zero-filled
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
assert cam in batch
assert f"{cam}_is_pad" in batch
# images should all be 3x224x224 (or your transform’s size)
img = batch[cam]
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader
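
# Optional deeper check (a sketch): assumes MultiLeRobotDataset tags each item with
# a "dataset_index" matching the order of repo_ids, and that cameras missing from a
# dataset are zero-filled, per the train_on_all_features comment above.
pusht_rows = batch["dataset_index"] == 0  # REPO_A was listed first
if pusht_rows.any():
    wrist = batch["obs_image.cam_left_wrist"][pusht_rows]
    assert (wrist == 0).all(), "pusht has no wrist cams; expected zero-filled frames"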
16 changes: 16 additions & 0 deletions examples/tester.sh
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
# storage / caches
RAID=/raid/jade
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
export HF_HOME=$RAID/.cache/huggingface
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
export WANDB_CACHE_DIR=$RAID/.cache/wandb
export TMPDIR=$RAID/.cache/tmp
mkdir -p $TMPDIR
export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl

python examples/tester.py
76 changes: 76 additions & 0 deletions src/lerobot/datasets/compute_stats.py
@@ -174,3 +174,79 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)

return aggregated_stats


def aggregate_stats_multi(
stats_list: list[dict[str, dict]],
max_action_dim: int | None = None,
max_state_dim: int | None = None,
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.

Supports heterogeneous robots by padding action/state stats to the max dim.
The final stats will have the union of all data keys from each of the stats dicts.

- new_min = elementwise min across datasets
- new_max = elementwise max across datasets
- new_mean = weighted mean (by count)
- new_std = recomputed from total variance
"""

data_keys = {key for stats in stats_list for key in stats}
aggregated_stats = {key: {} for key in data_keys}

    def _pad(arr: np.ndarray, target: int | None) -> np.ndarray:
if arr.ndim == 0: # scalar
return arr
if target is None or target <= 0 or arr.shape[-1] == target:
return arr
pad_width = [(0, 0)] * arr.ndim
pad_width[-1] = (0, target - arr.shape[-1])
return np.pad(arr, pad_width, mode="constant")
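    # e.g. _pad(np.ones(2), 14) -> shape (14,) with dims 2..13 zero-filled;
    # scalars and arrays already at the target width pass through unchanged.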

for key in data_keys:
stats_with_key = [stats[key] for stats in stats_list if key in stats]

# decide if this key should be padded
target_dim = None
if "action" in key and max_action_dim:
target_dim = max_action_dim
elif "state" in key and max_state_dim:
target_dim = max_state_dim

padded = []
counts = []
for s in stats_with_key:
mean = _pad(np.array(s["mean"]), target_dim)
std = _pad(np.array(s["std"]), target_dim)
min_ = _pad(np.array(s["min"]), target_dim)
max_ = _pad(np.array(s["max"]), target_dim)
count = s.get("count", 1)

padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count))
counts.append(count)

counts = np.array(counts, dtype=np.float64)
total_count = counts.sum()

means = np.stack([p["mean"] for p in padded])
stds = np.stack([p["std"] for p in padded])
mins = np.stack([p["min"] for p in padded])
maxs = np.stack([p["max"] for p in padded])

        # count-weighted mean across datasets (weights align with axis 0)
new_mean = np.average(means, axis=0, weights=counts)
new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts)

new_std = np.sqrt(new_var)

aggregated_stats[key] = {
"min": mins.min(axis=0),
"max": maxs.max(axis=0),
"mean": new_mean,
"std": new_std,
"count": int(total_count),
}

return aggregated_stats
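
# A minimal usage sketch (hypothetical values; assumes the per-dataset stats dicts
# follow the same {"mean", "std", "min", "max", "count"} layout consumed above):
#
#   stats_pusht = {"action": {"mean": np.zeros(2), "std": np.ones(2),
#                             "min": -np.ones(2), "max": np.ones(2), "count": 100}}
#   stats_aloha = {"action": {"mean": np.zeros(14), "std": np.ones(14),
#                             "min": -np.ones(14), "max": np.ones(14), "count": 300}}
#   agg = aggregate_stats_multi([stats_pusht, stats_aloha], max_action_dim=14)
#   assert agg["action"]["mean"].shape == (14,)  # pusht stats zero-padded 2 -> 14
#   assert agg["action"]["count"] == 400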