Skip to content

Commit 9f841c3

Browse files
committed
WIP
1 parent 8ec75f2 commit 9f841c3

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

download_and_upload_dataset.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
useless dependencies when using datasets.
44
"""
55

6-
from dataclasses import dataclass, field
76
import io
87
import json
98
import pickle
109
import shutil
1110
from pathlib import Path
12-
from typing import Any, ClassVar, Optional
1311

1412
import einops
1513
import h5py
@@ -23,12 +21,10 @@
2321

2422
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
2523

26-
27-
2824
# @dataclass
2925
# class VideoFrame:
3026
# """
31-
27+
3228
# Example:
3329

3430
# ```py
@@ -56,6 +52,7 @@
5652
# def decode_example(self, value):
5753
# return value
5854

55+
5956
def download_and_upload(root, revision, dataset_id):
6057
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
6158
if "pusht" in dataset_id:
@@ -310,7 +307,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
310307
data_dict = concatenate_episodes(ep_dicts)
311308

312309
features = {
313-
#"observation.image": Image(),
310+
# "observation.image": Image(),
314311
"observation.image": Value(dtype="int64", id="video"),
315312
"observation.state": Sequence(
316313
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __getitem__(self, idx):
8989
return item
9090

9191

92-
9392
def yuv_to_rgb(frames):
9493
assert frames.dtype == torch.uint8
9594
assert frames.ndim == 4
@@ -135,7 +134,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
135134
device = "cpu"
136135
width = None
137136
height = None
138-
image_format = "rgb" # or "yuv"
137+
image_format = "rgb" # or "yuv"
139138
frame_rate = None
140139

141140
filter_desc = []
@@ -172,7 +171,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
172171
if resize_height:
173172
scales.append(f"height={height}")
174173
filter_desc.append(f"scale={':'.join(scales)}")
175-
174+
176175
# choice of format
177176
if image_format is not None:
178177
if device == "cuda":
@@ -196,6 +195,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
196195
# create a stream and load a certain number of frames at a certain frame rate
197196
# TODO(rcadene): make sure this is the optimal way to do it
198197
from torchaudio.io import StreamReader
198+
199199
s = StreamReader(str(video_path))
200200
s.seek(timestamp)
201201
s.add_video_stream(**video_stream_kwgs)
@@ -204,12 +204,13 @@ def decode_video_frame_torchaudio(video_path, timestamp):
204204

205205
if "yuv" in image_format:
206206
frames = yuv_to_rgb(frames)
207-
207+
208208
if len(frames) > 1:
209209
return frames
210210

211211
return frames[0]
212212

213+
213214
def decode_video_frames_ffmpegio(video_path, timestamp):
214215
num_contiguous_frames = 1
215216
device = "cpu"
@@ -220,11 +221,13 @@ def decode_video_frames_ffmpegio(video_path, timestamp):
220221
)
221222
frames = torch.from_numpy(frames)
222223
import einops
224+
223225
frames = einops.rearrange(frames, "b h w c -> b c h w")
224226
if device == "cuda":
225227
frames = frames.to(device)
226228
return frames
227229

230+
228231
def _decode_frames_decord(video_path, timestamp):
229232
num_contiguous_frames = 1
230233
device = "cpu"
@@ -239,4 +242,4 @@ def _decode_frames_decord(video_path, timestamp):
239242
# frames = vr.get_batch([frame_id])
240243
# frames = torch.from_numpy(frames.asnumpy())
241244
# frames = einops.rearrange(frames, "b h w c -> b c h w")
242-
# return frames
245+
# return frames

0 commit comments

Comments
 (0)