Skip to content

[WIP] This should work but it doesn't #472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 70 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
16d698f
Start implementation of approximate mode
scotts Dec 16, 2024
9a5abce
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Dec 18, 2024
d95b128
Initial seek mode implementation in VideoDecoder.
scotts Dec 19, 2024
97ac764
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Dec 19, 2024
35f2e59
Added Python side support, extended tests.
scotts Dec 20, 2024
8c9aeac
Apply lints
scotts Dec 20, 2024
921b822
Default C++ tests to approximate mode
scotts Dec 20, 2024
b349282
Apply lints
scotts Dec 20, 2024
081a5bb
Updated metadata; all tests pass.
scotts Dec 20, 2024
802b881
Removed commened out code.
scotts Dec 20, 2024
911a3bc
Consolidated logic for timestamp batch. Big perf win.
scotts Dec 20, 2024
7267b5a
Consolidated logic for timestamp range.
scotts Dec 21, 2024
ae44f78
More mode consolidation.
scotts Dec 21, 2024
e4edaf4
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Jan 9, 2025
64ebefe
Merge branch 'main' into approx
NicolasHug Jan 10, 2025
f62af46
Provide constructor param names
scotts Jan 11, 2025
d15dfa0
getFramesSize -> getNumFrames
scotts Jan 11, 2025
8879c32
Use seek_mode to paramterize metadata tests
scotts Jan 11, 2025
8443de5
stream -> streamInfo
scotts Jan 11, 2025
446910d
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Jan 11, 2025
ebebb63
Merge branch 'approx' of github.com:scotts/torchcodec into approx
scotts Jan 11, 2025
e34ca31
seek -> seekMode
scotts Jan 11, 2025
67c1225
remove getFramePlayedAtTimestampNoDemuxInternal
scotts Jan 11, 2025
d13879d
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Jan 14, 2025
06bb2c3
Rationalize time based samplers and valid metadata
scotts Jan 14, 2025
8ed3c5e
Refactor setting and using scanned number of frames
scotts Jan 14, 2025
bc10db8
Tweak FrameInfo struct initialization
scotts Jan 14, 2025
a6a2b6a
Use validateFrameIndex in getFramesInIndices
scotts Jan 14, 2025
3cd6842
Update src/torchcodec/decoders/_video_decoder.py
scotts Jan 14, 2025
9a46404
Tweak VideoDecoder doc string
scotts Jan 14, 2025
f4c001b
Remove comment
scotts Jan 15, 2025
4574e95
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Jan 17, 2025
abad57b
FrameInfo struct initialization
scotts Jan 17, 2025
737e1b6
Remove explicit setting of seek_mode in unrelated tests
scotts Jan 21, 2025
99b0d4f
Merge branch 'main' of github.com:pytorch/torchcodec into approx
scotts Jan 21, 2025
32a0f8f
Handle stream names
NicolasHug Jan 22, 2025
64f4595
Handle frame names
NicolasHug Jan 22, 2025
c629073
Streams again
NicolasHug Jan 22, 2025
85569fb
metadata names
NicolasHug Jan 22, 2025
d914615
Handle options
NicolasHug Jan 22, 2025
621a64c
Frame again
NicolasHug Jan 22, 2025
f247884
Lint
NicolasHug Jan 22, 2025
55a5840
More videoStreamOptions
NicolasHug Jan 22, 2025
793c876
Merge branch 'main' of github.com:pytorch/torchcodec into renamingzzzz
NicolasHug Jan 22, 2025
79ea167
reduce diff
NicolasHug Jan 22, 2025
404b2e4
Fix C++ tests
NicolasHug Jan 22, 2025
941d6a3
Fix CUDA?
NicolasHug Jan 22, 2025
11779a7
Merge branch 'main' of github.com:pytorch/torchcodec into renamingzzzz
NicolasHug Jan 23, 2025
a7c5711
Use allStreamMetadata
NicolasHug Jan 23, 2025
16d5e52
Rename BatchDecodedOutput into FrameBatchOutput
NicolasHug Jan 23, 2025
78b095a
Rename RawDecodedOutput into AVFrameWithStreamIndex
NicolasHug Jan 23, 2025
e1e46ff
Rename DecodedOutput into FrameOutput
NicolasHug Jan 23, 2025
cd0a181
Rename .frames into .data
NicolasHug Jan 23, 2025
0111bfc
rename .frame to .data
NicolasHug Jan 23, 2025
42ea096
Use frameBatchOutput variable name
NicolasHug Jan 23, 2025
3a2ab2d
rename getFrameOutputWithFilter
NicolasHug Jan 23, 2025
05da318
Use frameOutput name
NicolasHug Jan 23, 2025
a6a47f1
lint
NicolasHug Jan 23, 2025
d1352fe
Merge branch 'main' of github.com:pytorch/torchcodec into framezzzzz
NicolasHug Jan 23, 2025
658b727
getNextAVFrameNoDemux
NicolasHug Jan 23, 2025
62818a7
more stuff
NicolasHug Jan 23, 2025
24d0035
use avFrameWithStreamIndex
NicolasHug Jan 23, 2025
a1805d1
Cpp tests
NicolasHug Jan 23, 2025
eb1773a
more
NicolasHug Jan 23, 2025
21b8ff2
Remove getNextAVFrameNoDemux
NicolasHug Jan 23, 2025
dac4463
Rename maybeDesiredPts_ into desiredPts_
NicolasHug Jan 23, 2025
bc024eb
Merge branch 'desiredPTs' into aelfjnalefjnalfjenljaenf
NicolasHug Jan 23, 2025
a17027c
this should work
NicolasHug Jan 23, 2025
738a262
Merge branch 'main' of github.com:pytorch/torchcodec into this_should…
NicolasHug Jan 23, 2025
86c6ffd
Try
NicolasHug Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor setting and using scanned number of frames
  • Loading branch information
scotts committed Jan 14, 2025
commit 8ed3c5e5c4df2b041ad7c39a6c6e6b46ac46dd53
55 changes: 33 additions & 22 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,41 +570,51 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
if (scannedAllStreams_) {
return;
}

while (true) {
// Get the next packet.
UniqueAVPacket packet(av_packet_alloc());
int ffmpegStatus = av_read_frame(formatContext_.get(), packet.get());

if (ffmpegStatus == AVERROR_EOF) {
break;
}

if (ffmpegStatus != AVSUCCESS) {
throw std::runtime_error(
"Failed to read frame from input file: " +
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}
int streamIndex = packet->stream_index;

if (packet->flags & AV_PKT_FLAG_DISCARD) {
continue;
}
auto& stream = containerMetadata_.streams[streamIndex];
stream.minPtsFromScan =
std::min(stream.minPtsFromScan.value_or(INT64_MAX), packet->pts);
stream.maxPtsFromScan = std::max(
stream.maxPtsFromScan.value_or(INT64_MIN),
packet->pts + packet->duration);
stream.numFramesFromScan = stream.numFramesFromScan.value_or(0) + 1;

FrameInfo frameInfo;
frameInfo.pts = packet->pts;
// We got a valid packet. Let's figure out what stream it belongs to and
// record its relevant metadata.
int streamIndex = packet->stream_index;
auto& streamMetadata = containerMetadata_.streams[streamIndex];
streamMetadata.minPtsFromScan = std::min(
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
streamMetadata.maxPtsFromScan = std::max(
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
packet->pts + packet->duration);

FrameInfo frameInfo{.pts = packet->pts};
if (packet->flags & AV_PKT_FLAG_KEY) {
streams_[streamIndex].keyFrames.push_back(frameInfo);
}
streams_[streamIndex].allFrames.push_back(frameInfo);
}

// Set all per-stream metadata that requires knowing the content of all
// packets.
for (int i = 0; i < containerMetadata_.streams.size(); ++i) {
auto& streamMetadata = containerMetadata_.streams[i];
auto stream = formatContext_->streams[i];

streamMetadata.numFramesFromScan = streams_[i].allFrames.size();

if (streamMetadata.minPtsFromScan.has_value()) {
streamMetadata.minPtsSecondsFromScan =
*streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
Expand All @@ -614,13 +624,17 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
*streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
}
}

// Reset the seek-cursor back to the beginning.
int ffmepgStatus =
avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
if (ffmepgStatus < 0) {
throw std::runtime_error(
"Could not seek file to pts=0: " +
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
}

// Sort all frames by their pts.
for (auto& [streamIndex, stream] : streams_) {
std::sort(
stream.keyFrames.begin(),
Expand All @@ -641,6 +655,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
}
}
}

scannedAllStreams_ = true;
}

Expand Down Expand Up @@ -1098,14 +1113,13 @@ void VideoDecoder::validateScannedAllStreams(const std::string& msg) {
}

void VideoDecoder::validateFrameIndex(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata,
int64_t frameIndex) {
int64_t numFrames = getNumFrames(streamInfo, streamMetadata);
int64_t numFrames = getNumFrames(streamMetadata);
TORCH_CHECK(
frameIndex >= 0 && frameIndex < numFrames,
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamInfo.streamIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
" numFrames=" + std::to_string(numFrames));
}

Expand All @@ -1132,12 +1146,10 @@ int64_t VideoDecoder::getPts(
}
}

int64_t VideoDecoder::getNumFrames(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata) {
int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamInfo.allFrames.size();
return streamMetadata.numFramesFromScan.value();
case SeekMode::approximate:
return streamMetadata.numFrames.value();
default:
Expand Down Expand Up @@ -1221,7 +1233,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(

const auto& streamInfo = streams_[streamIndex];
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
validateFrameIndex(streamInfo, streamMetadata, frameIndex);
validateFrameIndex(streamMetadata, frameIndex);

int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
Expand Down Expand Up @@ -1261,8 +1273,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
for (auto f = 0; f < frameIndices.size(); ++f) {
auto indexInOutput = indicesAreSorted ? f : argsort[f];
auto indexInVideo = frameIndices[indexInOutput];
if (indexInVideo < 0 ||
indexInVideo >= getNumFrames(stream, streamMetadata)) {
if (indexInVideo < 0 || indexInVideo >= getNumFrames(streamMetadata)) {
throw std::runtime_error(
"Invalid frame index=" + std::to_string(indexInVideo));
}
Expand Down Expand Up @@ -1327,7 +1338,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];
int64_t numFrames = getNumFrames(stream, streamMetadata);
int64_t numFrames = getNumFrames(streamMetadata);
TORCH_CHECK(
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
TORCH_CHECK(
Expand Down Expand Up @@ -1476,7 +1487,7 @@ double VideoDecoder::getPtsSecondsForFrame(

const auto& streamInfo = streams_[streamIndex];
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
validateFrameIndex(streamInfo, streamMetadata, frameIndex);
validateFrameIndex(streamMetadata, frameIndex);

return ptsToSeconds(
streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
Expand Down
5 changes: 1 addition & 4 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ class VideoDecoder {
void validateUserProvidedStreamIndex(uint64_t streamIndex);
void validateScannedAllStreams(const std::string& msg);
void validateFrameIndex(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata,
int64_t frameIndex);

Expand All @@ -384,9 +383,7 @@ class VideoDecoder {
int expectedOutputHeight,
int expectedOutputWidth);

int64_t getNumFrames(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata);
int64_t getNumFrames(const StreamMetadata& streamMetadata);

int64_t getPts(
const StreamInfo& streamInfo,
Expand Down
Loading