@@ -89,7 +89,6 @@ def __getitem__(self, idx):
89
89
return item
90
90
91
91
92
-
93
92
def yuv_to_rgb (frames ):
94
93
assert frames .dtype == torch .uint8
95
94
assert frames .ndim == 4
@@ -135,7 +134,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
135
134
device = "cpu"
136
135
width = None
137
136
height = None
138
- image_format = "rgb" # or "yuv"
137
+ image_format = "rgb" # or "yuv"
139
138
frame_rate = None
140
139
141
140
filter_desc = []
@@ -172,7 +171,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
172
171
if resize_height :
173
172
scales .append (f"height={ height } " )
174
173
filter_desc .append (f"scale={ ':' .join (scales )} " )
175
-
174
+
176
175
# choice of format
177
176
if image_format is not None :
178
177
if device == "cuda" :
@@ -196,6 +195,7 @@ def decode_video_frame_torchaudio(video_path, timestamp):
196
195
# create a stream and load a certain number of frame at a certain frame rate
197
196
# TODO(rcadene): make sure it's the most optimal way to do it
198
197
from torchaudio .io import StreamReader
198
+
199
199
s = StreamReader (str (video_path ))
200
200
s .seek (timestamp )
201
201
s .add_video_stream (** video_stream_kwgs )
@@ -204,12 +204,13 @@ def decode_video_frame_torchaudio(video_path, timestamp):
204
204
205
205
if "yuv" in image_format :
206
206
frames = yuv_to_rgb (frames )
207
-
207
+
208
208
if len (frames ) > 1 :
209
209
return frames
210
210
211
211
return frames [0 ]
212
212
213
+
213
214
def decode_video_frames_ffmpegio (video_path , timestamp ):
214
215
num_contiguous_frames = 1
215
216
device = "cpu"
@@ -220,11 +221,13 @@ def decode_video_frames_ffmpegio(video_path, timestamp):
220
221
)
221
222
frames = torch .from_numpy (frames )
222
223
import einops
224
+
223
225
frames = einops .rearrange (frames , "b h w c -> b c h w" )
224
226
if device == "cuda" :
225
227
frames = frames .to (device )
226
228
return frames
227
229
230
+
228
231
def _decode_frames_decord (video_path , timestamp ):
229
232
num_contiguous_frames = 1
230
233
device = "cpu"
@@ -239,4 +242,4 @@ def _decode_frames_decord(video_path, timestamp):
239
242
# frames = vr.get_batch([frame_id])
240
243
# frames = torch.from_numpy(frames.asnumpy())
241
244
# frames = einops.rearrange(frames, "b h w c -> b c h w")
242
- # return frames
245
+ # return frames
0 commit comments