Skip to content

AudioDecoder: specify desired num_channels #678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,11 @@ void AudioEncoder::encodeInnerLoop(
AV_SAMPLE_FMT_FLTP,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate));
srcAVFrame->sample_rate,
2 // TODO
));
}
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
convertedAVFrame = convertAudioAVFrameSamples(
swrContext_,
srcAVFrame,
avCodecContext_->sample_fmt,
Expand Down
19 changes: 14 additions & 5 deletions src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,26 @@ SwrContext* createSwrContext(
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
int desiredSampleRate,
int desiredNumChannels) {
SwrContext* swrContext = nullptr;
int status = AVSUCCESS;
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
AVChannelLayout layout = avCodecContext->ch_layout;
AVChannelLayout sourceLayout = avCodecContext->ch_layout;
AVChannelLayout desiredLayout;
if (desiredNumChannels == getNumChannels(avCodecContext)) {
status = av_channel_layout_copy(&desiredLayout, &sourceLayout);
TORCH_CHECK(status == AVSUCCESS, "TODO");
} else {
av_channel_layout_default(&desiredLayout, desiredNumChannels);
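// TODO: av_channel_layout_default() returns void, so it cannot report failure.
// Validate desiredNumChannels / desiredLayout explicitly after this call
// (e.g. with av_channel_layout_check()) and raise a clear error if invalid.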
}
status = swr_alloc_set_opts2(
&swrContext,
&layout,
&desiredLayout,
desiredSampleFormat,
desiredSampleRate,
&layout,
&sourceLayout,
sourceSampleFormat,
sourceSampleRate,
0,
Expand Down Expand Up @@ -167,7 +176,7 @@ SwrContext* createSwrContext(
return swrContext;
}

UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
UniqueAVFrame convertAudioAVFrameSamples(
const UniqueSwrContext& swrContext,
const UniqueAVFrame& srcAVFrame,
AVSampleFormat desiredSampleFormat,
Expand Down
5 changes: 3 additions & 2 deletions src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ SwrContext* createSwrContext(
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);
int desiredSampleRate,
int desiredNumChannels);

UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
UniqueAVFrame convertAudioAVFrameSamples(
const UniqueSwrContext& swrContext,
const UniqueAVFrame& srcAVFrame,
AVSampleFormat desiredSampleFormat,
Expand Down
21 changes: 14 additions & 7 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1355,9 +1355,14 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
int desiredSampleRate =
streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate);

int sourceNumChannels = getNumChannels(srcAVFrame);
int desiredNumChannels =
streamInfo.audioStreamOptions.numChannels.value_or(sourceNumChannels);

bool mustConvert =
(sourceSampleFormat != desiredSampleFormat ||
sourceSampleRate != desiredSampleRate);
sourceSampleRate != desiredSampleRate ||
sourceNumChannels != desiredNumChannels);

UniqueAVFrame convertedAVFrame;
if (mustConvert) {
Expand All @@ -1367,10 +1372,11 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate));
desiredSampleRate,
desiredNumChannels));
}

convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
convertedAVFrame = convertAudioAVFrameSamples(
streamInfo.swrContext,
srcAVFrame,
desiredSampleFormat,
Expand All @@ -1389,15 +1395,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
av_get_sample_fmt_name(format));

auto numSamples = avFrame->nb_samples; // per channel
auto numChannels = getNumChannels(avFrame);

frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);
frameOutput.data =
torch::empty({desiredNumChannels, numSamples}, torch::kFloat32);

if (numSamples > 0) {
uint8_t* outputChannelData =
static_cast<uint8_t*>(frameOutput.data.data_ptr());
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
for (auto channel = 0; channel < numChannels;
for (auto channel = 0; channel < desiredNumChannels;
++channel, outputChannelData += numBytesPerChannel) {
std::memcpy(
outputChannelData,
Expand All @@ -1424,7 +1430,8 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
return std::nullopt;
}

auto numChannels = getNumChannels(streamInfo.codecContext);
int numChannels = streamInfo.audioStreamOptions.numChannels.value_or(
getNumChannels(streamInfo.codecContext));
torch::Tensor lastSamples =
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);

Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct AudioStreamOptions {
AudioStreamOptions() {}

std::optional<int> sampleRate;
std::optional<int> numChannels;
};

} // namespace facebook::torchcodec
6 changes: 4 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
m.def(
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()");
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
m.def(
Expand Down Expand Up @@ -280,9 +280,11 @@ void add_video_stream(
void add_audio_stream(
at::Tensor& decoder,
std::optional<int64_t> stream_index = std::nullopt,
std::optional<int64_t> sample_rate = std::nullopt) {
std::optional<int64_t> sample_rate = std::nullopt,
std::optional<int64_t> num_channels = std::nullopt) {
AudioStreamOptions audioStreamOptions;
audioStreamOptions.sampleRate = sample_rate;
audioStreamOptions.numChannels = num_channels;

auto videoDecoder = unwrapTensorToGetDecoder(decoder);
videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def add_audio_stream_abstract(
decoder: torch.Tensor,
*,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> None:
return

Expand Down
8 changes: 7 additions & 1 deletion src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class AudioDecoder:
the :term:`best stream` is used.
sample_rate (int, optional): The desired output sample rate of the decoded samples.
By default, the samples are returned in their original sample rate.
num_channels (int, optional): The desired number of channels of the decoded samples.
By default, the original number of channels is used.

Attributes:
metadata (AudioStreamMetadata): Metadata of the audio stream.
Expand All @@ -54,11 +56,15 @@ def __init__(
*,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
num_channels: Optional[int] = None,
):
self._decoder = create_decoder(source=source, seek_mode="approximate")

core.add_audio_stream(
self._decoder, stream_index=stream_index, sample_rate=sample_rate
self._decoder,
stream_index=stream_index,
sample_rate=sample_rate,
num_channels=num_channels,
)

container_metadata = core.get_container_metadata(self._decoder)
Expand Down
Loading